diff --git a/routes/chat_helpers.py b/routes/chat_helpers.py index cc927eec9..60198194a 100644 --- a/routes/chat_helpers.py +++ b/routes/chat_helpers.py @@ -159,17 +159,9 @@ async def auto_name_session(session_manager, sess): return owner = getattr(sess, "owner", None) - t_url, t_model, t_headers = resolve_task_endpoint(owner=owner) - if not t_model: - # If no task/utility model is configured at all, fall back to - # the session's own model so auto-naming still works even on - # minimal setups. - from src.endpoint_resolver import resolve_endpoint - _fallback = resolve_endpoint("default", owner=owner) - if _fallback and _fallback[1]: - t_url, t_model, t_headers = _fallback - else: - t_url, t_model, t_headers = sess.endpoint_url, sess.model, sess.headers + t_url, t_model, t_headers = resolve_task_endpoint( + sess.endpoint_url, sess.model, sess.headers, owner=owner + ) if not t_model: logger.debug("[auto-name] No model provided, skipping") return @@ -576,7 +568,8 @@ async def build_chat_context( if not incognito: fire_message_event(request, webhook_manager, session_id, sess, message, compare_mode) - # Resolve user prefs + # Resolve owner-scoped prefs/context. Browser requests keep the cookie user; + # bearer-token chat requests use the token owner instead of the "api" sentinel. user = effective_user(request) uprefs = load_prefs_for_user(user) diff --git a/routes/memory_routes.py b/routes/memory_routes.py index e788f82d2..d290046ec 100644 --- a/routes/memory_routes.py +++ b/routes/memory_routes.py @@ -273,65 +273,30 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM async def api_audit_memories(request: Request, session: str = Form(None)): """Deduplicate and consolidate memories via LLM. - Uses the default model from settings, or falls back to a session's model. + Uses task/utility/default settings through the shared resolver, with + the active session as fallback when no task or utility model is set. Returns before and after memory counts. """ - from routes.model_routes import _load_settings, _normalize_base, build_chat_url - from core.database import ModelEndpoint - import json as _json - - endpoint_url = model = None - headers = {} - - # Try utility model from settings first — memory audit is a background - # task and should prefer the lighter utility model over the main chat model. - from src.task_endpoint import resolve_task_endpoint user = _owner(request) - t_url, t_model, t_headers = resolve_task_endpoint(owner=user) - if t_url and t_model: - endpoint_url, model, headers = t_url, t_model, t_headers - else: - # Fall back to default model if no task/utility model configured - settings = _load_settings() - ep_id = settings.get("default_endpoint_id", "") - default_model = settings.get("default_model", "") - if ep_id: - db = SessionLocal() - try: - ep = db.query(ModelEndpoint).filter( - ModelEndpoint.id == ep_id, ModelEndpoint.is_enabled == True - ).first() - if ep: - base = _normalize_base(ep.base_url) - endpoint_url = build_chat_url(base) - model = default_model - if not model and ep.models: - try: - models = _json.loads(ep.models) if isinstance(ep.models, str) else ep.models - if models: - model = models[0] - except Exception: - pass - if ep.api_key: - headers = {"Authorization": f"Bearer {ep.api_key}"} - finally: - db.close() + fallback_url = fallback_model = None + fallback_headers = None + if session: + try: + sess = session_manager.get_session(session) + _assert_session_owner(sess, user) + fallback_url = sess.endpoint_url + fallback_model = sess.model + fallback_headers = sess.headers + except KeyError: + pass - # Fall back to session model if no default configured - if not endpoint_url and session: - try: - sess = session_manager.get_session(session) - _assert_session_owner(sess, _owner(request)) - endpoint_url = sess.endpoint_url - model = sess.model - headers = sess.headers - except KeyError: - pass + endpoint_url, model, headers = resolve_task_endpoint( + fallback_url, fallback_model, fallback_headers, owner=user + ) if not endpoint_url or not model: raise HTTPException(400, "No default model configured — set one in Settings") - user = _owner(request) result = await audit_memories( memory_manager, memory_vector, @@ -369,18 +334,28 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM model = None headers = {} + user = _owner(request) + if session: try: sess = session_manager.get_session(session) - _assert_session_owner(sess, _owner(request)) - endpoint_url, model, headers = resolve_task_endpoint( - sess.endpoint_url, sess.model, sess.headers, owner=_owner(request) - ) + _assert_session_owner(sess, user) except KeyError: - logger.warning("Session %s not found, falling back to utility endpoint", session) - endpoint_url, model, headers = resolve_endpoint("utility", owner=_owner(request)) + sess = None + except HTTPException as exc: + if exc.status_code != 404: + raise + sess = None + + if sess is None: + logger.warning("Session %s not found or inaccessible, falling back to utility endpoint", session) + endpoint_url, model, headers = resolve_endpoint("utility", owner=user) + else: + endpoint_url, model, headers = resolve_task_endpoint( + sess.endpoint_url, sess.model, sess.headers, owner=user + ) else: - endpoint_url, model, headers = resolve_task_endpoint(owner=_owner(request)) + endpoint_url, model, headers = resolve_task_endpoint(owner=user) if not endpoint_url or not model: raise HTTPException(400, "No LLM model configured. Set a default model in Settings.") diff --git a/tests/test_chat_helpers.py b/tests/test_chat_helpers.py index 6b3ec87e0..92bd51561 100644 --- a/tests/test_chat_helpers.py +++ b/tests/test_chat_helpers.py @@ -1,10 +1,19 @@ +import asyncio +from types import SimpleNamespace + import pytest from fastapi import HTTPException +import routes.chat_helpers as chat_helpers from routes.chat_helpers import ( _enforce_chat_privileges, + _session_is_research_spinoff, + auto_name_session, + build_chat_context, clean_thinking_for_save, needs_auto_name, + PreprocessedMessage, + PresetInfo, save_assistant_response, ) @@ -220,10 +229,6 @@ def test_save_assistant_response_preserves_actual_and_requested_model(): assert sess.history[-1].metadata["model"] == "actual-model" -from types import SimpleNamespace -from routes.chat_helpers import _session_is_research_spinoff - - class _SpinMsg: def __init__(self, role, metadata=None): self.role = role @@ -238,6 +243,57 @@ def test_spinoff_detected_from_chatmessage_history(): assert _session_is_research_spinoff(sess) is True +def test_auto_name_session_passes_session_fallback_to_task_resolver(monkeypatch): + import src.llm_core as llm_core + import src.task_endpoint as task_endpoint + + resolver_calls = [] + llm_calls = [] + + def fake_resolve_task_endpoint( + fallback_url=None, + fallback_model=None, + fallback_headers=None, + owner=None, + ): + resolver_calls.append((fallback_url, fallback_model, fallback_headers, owner)) + return fallback_url, fallback_model, fallback_headers + + async def fake_llm_call(url, model, messages, **kwargs): + llm_calls.append((url, model, messages, kwargs)) + return "Focused Fix" + + monkeypatch.setattr(task_endpoint, "resolve_task_endpoint", fake_resolve_task_endpoint) + monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call) + + session_headers = {"Authorization": "Bearer session"} + sess = SimpleNamespace( + id="session-1", + owner="alice", + endpoint_url="http://session.example/v1/chat/completions", + model="session-model", + headers=session_headers, + history=[SimpleNamespace(role="user", content="Please fix the endpoint fallback bug.")], + ) + updates = [] + session_manager = SimpleNamespace( + update_session_name=lambda session_id, title: updates.append((session_id, title)) + ) + + asyncio.run(auto_name_session(session_manager, sess)) + + assert resolver_calls == [( + "http://session.example/v1/chat/completions", + "session-model", + session_headers, + "alice", + )] + assert llm_calls[0][0] == "http://session.example/v1/chat/completions" + assert llm_calls[0][1] == "session-model" + assert llm_calls[0][3]["headers"] == session_headers + assert updates == [("session-1", "Focused Fix")] + + def test_spinoff_detected_from_dict_history(): sess = SimpleNamespace(history=[ {"role": "system", "metadata": {"research_spinoff_from": "rp-2"}}, @@ -262,3 +318,120 @@ def test_metadata_on_non_system_message_ignored(): def test_empty_or_missing_history(): assert _session_is_research_spinoff(SimpleNamespace(history=[])) is False assert _session_is_research_spinoff(SimpleNamespace()) is False + + +async def _build_context_owner_probe(monkeypatch, request_state): + captured = { + "prefs_owner": None, + "preface_owner": None, + "compact_owner": None, + } + + async def fake_preprocess(chat_handler, message, att_ids, sess, **kwargs): + return PreprocessedMessage( + enhanced_message=message, + user_content=message, + text_for_context=message, + youtube_transcripts=[], + attachment_meta=[], + ) + + def fake_extract_preset(chat_handler, preset_id): + return PresetInfo( + temperature=0.7, + max_tokens=1024, + system_prompt=None, + character_name=None, + ) + + def fake_add_user_message(sess, chat_handler, preprocessed, incognito=False): + sess.messages.append({"role": "user", "content": preprocessed.user_content}) + + def fake_load_prefs(owner): + captured["prefs_owner"] = owner + return {"memory_enabled": True, "skills_enabled": True} + + def fake_build_context_preface(**kwargs): + captured["preface_owner"] = kwargs["owner"] + return [], [], [] + + async def fake_maybe_compact(sess, endpoint_url, model, messages, headers, owner=None): + captured["compact_owner"] = owner + return messages, 8192, False + + monkeypatch.setattr(chat_helpers, "preprocess", fake_preprocess) + monkeypatch.setattr(chat_helpers, "extract_preset", fake_extract_preset) + monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message) + monkeypatch.setattr(chat_helpers, "load_prefs_for_user", fake_load_prefs) + monkeypatch.setattr(chat_helpers, "_normalize_model_id_from_cache", lambda sess: None) + monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None) + monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact) + monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages) + + import src.user_time as user_time + + monkeypatch.setattr( + user_time, + "current_datetime_context_message", + lambda now_utc=None: {"role": "user", "content": "[Context - current date/time]"}, + raising=False, + ) + + sess = SimpleNamespace( + endpoint_url="http://model.local/v1/chat/completions", + model="test-model", + headers={}, + history=[], + messages=[], + ) + sess.get_context_messages = lambda: list(sess.messages) + + request = SimpleNamespace(state=SimpleNamespace(**request_state)) + ctx = await build_chat_context( + sess=sess, + request=request, + chat_handler=SimpleNamespace(), + chat_processor=SimpleNamespace(build_context_preface=fake_build_context_preface), + message="hello", + session_id="session-1", + incognito=True, + ) + + return ctx, captured + + +@pytest.mark.asyncio +async def test_build_chat_context_uses_api_token_owner_for_compaction_scope(monkeypatch): + ctx, captured = await _build_context_owner_probe( + monkeypatch, + { + "api_token": True, + "api_token_owner": "alice", + "current_user": "api", + }, + ) + + assert ctx.user == "alice" + assert captured == { + "prefs_owner": "alice", + "preface_owner": "alice", + "compact_owner": "alice", + } + + +@pytest.mark.asyncio +async def test_build_chat_context_keeps_cookie_user_owner_scope(monkeypatch): + ctx, captured = await _build_context_owner_probe( + monkeypatch, + { + "api_token": False, + "current_user": "bob", + }, + ) + + assert ctx.user == "bob" + assert captured == { + "prefs_owner": "bob", + "preface_owner": "bob", + "compact_owner": "bob", + } diff --git a/tests/test_memory_routes_session_owner.py b/tests/test_memory_routes_session_owner.py index be5e05e03..29a759148 100644 --- a/tests/test_memory_routes_session_owner.py +++ b/tests/test_memory_routes_session_owner.py @@ -7,11 +7,14 @@ another tenant's session and leak their chat history, session-scoped LLM credentials, or session title. """ import asyncio +import io +import sys +import types from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from fastapi import HTTPException +from fastapi import HTTPException, UploadFile import routes.memory_routes as mr from src.request_models import MemoryAddRequest @@ -46,6 +49,17 @@ def _request(user): ) +def _upload(name="memories.json"): + return UploadFile( + filename=name, + file=io.BytesIO(b'[{"text": "Project Phoenix uses Python", "category": "project"}]'), + ) + + +def _allow_memory_management(monkeypatch): + monkeypatch.setattr("src.auth_helpers.require_privilege", lambda request, privilege: "alice") + + def test_extract_rejects_other_users_session(monkeypatch): router = _router(monkeypatch, caller="bob") extract = _route(router, "/api/memory/extract", "POST") @@ -69,6 +83,78 @@ def test_owner_can_access_own_session(monkeypatch): assert out["session_name"] == "Secret project" +def test_audit_session_fallback_uses_resolver_without_manual_default(monkeypatch): + import src.task_endpoint as task_endpoint + + memory_manager = MagicMock() + memory_vector = MagicMock() + session_headers = {"Authorization": "Bearer session"} + session_manager = MagicMock() + session_manager.get_session.return_value = SimpleNamespace( + owner="alice", + endpoint_url="http://session.example/v1/chat/completions", + model="session-model", + headers=session_headers, + ) + router = mr.setup_memory_routes(memory_manager, session_manager, memory_vector) + audit_route = _route(router, "/api/memory/audit", "POST") + + resolver_calls = [] + audit_calls = [] + + def fake_resolve_task_endpoint( + fallback_url=None, + fallback_model=None, + fallback_headers=None, + owner=None, + ): + resolver_calls.append((fallback_url, fallback_model, fallback_headers, owner)) + if fallback_url and fallback_model: + return fallback_url, fallback_model, fallback_headers + return None, None, {} + + async def fake_audit_memories(memory_manager_arg, memory_vector_arg, endpoint_url, model, headers, owner=None): + audit_calls.append((memory_manager_arg, memory_vector_arg, endpoint_url, model, headers, owner)) + return {"before": 2, "after": 1} + + fake_model_routes = types.ModuleType("routes.model_routes") + fake_model_routes._load_settings = lambda: { + "default_endpoint_id": "default", + "default_model": "default-model", + } + fake_model_routes._normalize_base = lambda base: base.rstrip("/") + fake_model_routes.build_chat_url = lambda base: f"{base}/chat/completions" + + monkeypatch.setattr(mr, "resolve_task_endpoint", fake_resolve_task_endpoint) + monkeypatch.setattr(task_endpoint, "resolve_task_endpoint", fake_resolve_task_endpoint) + monkeypatch.setattr(mr, "audit_memories", fake_audit_memories) + monkeypatch.setitem(sys.modules, "routes.model_routes", fake_model_routes) + monkeypatch.setattr( + mr, + "SessionLocal", + lambda: (_ for _ in ()).throw(AssertionError("manual default branch should not run")), + ) + + out = asyncio.run(audit_route(request=_request("alice"), session="session-1")) + + assert resolver_calls == [( + "http://session.example/v1/chat/completions", + "session-model", + session_headers, + "alice", + )] + assert audit_calls == [( + memory_manager, + memory_vector, + "http://session.example/v1/chat/completions", + "session-model", + session_headers, + "alice", + )] + assert out["ok"] is True + assert out["removed"] == 1 + + def test_add_memory_rejects_other_users_session(monkeypatch): memory_manager = MagicMock() session_manager = MagicMock() @@ -125,3 +211,79 @@ def test_timeline_does_not_expose_other_users_session_name(): out = timeline(request=_request("alice")) assert out["timeline"][0]["session_name"] == "Unknown" + + +def test_import_missing_session_uses_utility_fallback(monkeypatch): + _allow_memory_management(monkeypatch) + memory_manager = MagicMock() + session_manager = MagicMock() + session_manager.get_session.side_effect = KeyError + resolve_endpoint = MagicMock(return_value=("http://utility", "utility-model", {})) + resolve_task_endpoint = MagicMock(side_effect=AssertionError("session task endpoint should not be used")) + monkeypatch.setattr(mr, "resolve_endpoint", resolve_endpoint) + monkeypatch.setattr(mr, "resolve_task_endpoint", resolve_task_endpoint) + router = mr.setup_memory_routes(memory_manager, session_manager) + import_memories = _route(router, "/api/memory/import", "POST") + + out = asyncio.run(import_memories(request=_request("alice"), session="missing-session", file=_upload())) + + assert out == { + "suggestions": [{"text": "Project Phoenix uses Python", "category": "project"}], + "filename": "memories.json", + } + session_manager.get_session.assert_called_once_with("missing-session") + resolve_endpoint.assert_called_once_with("utility", owner="alice") + + +def test_import_foreign_session_uses_same_utility_fallback(monkeypatch): + _allow_memory_management(monkeypatch) + memory_manager = MagicMock() + session_manager = MagicMock() + session_manager.get_session.return_value = SimpleNamespace( + owner="bob", + endpoint_url="http://bob-llm", + model="bob-model", + headers={"Authorization": "Bearer bob-secret"}, + ) + resolve_endpoint = MagicMock(return_value=("http://utility", "utility-model", {})) + resolve_task_endpoint = MagicMock(side_effect=AssertionError("foreign session endpoint should not be used")) + monkeypatch.setattr(mr, "resolve_endpoint", resolve_endpoint) + monkeypatch.setattr(mr, "resolve_task_endpoint", resolve_task_endpoint) + router = mr.setup_memory_routes(memory_manager, session_manager) + import_memories = _route(router, "/api/memory/import", "POST") + + out = asyncio.run(import_memories(request=_request("alice"), session="bob-session", file=_upload())) + + assert out["suggestions"] == [{"text": "Project Phoenix uses Python", "category": "project"}] + session_manager.get_session.assert_called_once_with("bob-session") + resolve_endpoint.assert_called_once_with("utility", owner="alice") + + +def test_import_owned_session_uses_session_endpoint(monkeypatch): + _allow_memory_management(monkeypatch) + memory_manager = MagicMock() + session_manager = MagicMock() + session_manager.get_session.return_value = SimpleNamespace( + owner="alice", + endpoint_url="http://alice-llm", + model="alice-model", + headers={"X-Session": "alice"}, + ) + resolve_endpoint = MagicMock(side_effect=AssertionError("utility fallback should not be used")) + resolve_task_endpoint = MagicMock(return_value=("http://alice-task", "alice-task-model", {"X-Task": "alice"})) + monkeypatch.setattr(mr, "resolve_endpoint", resolve_endpoint) + monkeypatch.setattr(mr, "resolve_task_endpoint", resolve_task_endpoint) + router = mr.setup_memory_routes(memory_manager, session_manager) + import_memories = _route(router, "/api/memory/import", "POST") + + out = asyncio.run(import_memories(request=_request("alice"), session="alice-session", file=_upload())) + + assert out["suggestions"] == [{"text": "Project Phoenix uses Python", "category": "project"}] + session_manager.get_session.assert_called_once_with("alice-session") + resolve_task_endpoint.assert_called_once_with( + "http://alice-llm", + "alice-model", + {"X-Session": "alice"}, + owner="alice", + ) + resolve_endpoint.assert_not_called() diff --git a/tests/test_resolve_endpoint_fallbacks.py b/tests/test_resolve_endpoint_fallbacks.py index e77a83ae7..c210ecf19 100644 --- a/tests/test_resolve_endpoint_fallbacks.py +++ b/tests/test_resolve_endpoint_fallbacks.py @@ -147,6 +147,26 @@ def test_returns_explicit_fallback_when_no_endpoint_id_configured(monkeypatch): ) == fallback +def test_task_session_fallback_wins_before_default_when_task_and_utility_unset(monkeypatch): + settings = { + "task_endpoint_id": "", + "task_model": "", + "utility_endpoint_id": "", + "utility_model": "", + "default_endpoint_id": "default", + "default_model": "default-chat", + } + fallback = ("https://session.example/chat", "session-chat", {"X-Test": "session"}) + _install_resolver_fakes(monkeypatch, settings, [_endpoint("default", "default-chat")]) + + assert resolve_endpoint( + "task", + fallback_url=fallback[0], + fallback_model=fallback[1], + fallback_headers=fallback[2], + ) == fallback + + def test_hidden_configured_model_selects_first_enabled_chat_model(monkeypatch): settings = { "default_endpoint_id": "default",