fix(routes): normalize session owner fallback helpers (#4313)

* fix(memory): normalize import session fallback

* fix(chat): use token owner for compaction scope

* fix(background): honor session endpoint fallback
This commit is contained in:
RaresKeY
2026-06-16 08:07:42 +03:00
committed by GitHub
parent d795d9a923
commit 2b519bf355
5 changed files with 398 additions and 75 deletions
+5 -12
View File
@@ -159,17 +159,9 @@ async def auto_name_session(session_manager, sess):
return return
owner = getattr(sess, "owner", None) owner = getattr(sess, "owner", None)
t_url, t_model, t_headers = resolve_task_endpoint(owner=owner) t_url, t_model, t_headers = resolve_task_endpoint(
if not t_model: sess.endpoint_url, sess.model, sess.headers, owner=owner
# If no task/utility model is configured at all, fall back to )
# the session's own model so auto-naming still works even on
# minimal setups.
from src.endpoint_resolver import resolve_endpoint
_fallback = resolve_endpoint("default", owner=owner)
if _fallback and _fallback[1]:
t_url, t_model, t_headers = _fallback
else:
t_url, t_model, t_headers = sess.endpoint_url, sess.model, sess.headers
if not t_model: if not t_model:
logger.debug("[auto-name] No model provided, skipping") logger.debug("[auto-name] No model provided, skipping")
return return
@@ -576,7 +568,8 @@ async def build_chat_context(
if not incognito: if not incognito:
fire_message_event(request, webhook_manager, session_id, sess, message, compare_mode) fire_message_event(request, webhook_manager, session_id, sess, message, compare_mode)
# Resolve user prefs # Resolve owner-scoped prefs/context. Browser requests keep the cookie user;
# bearer-token chat requests use the token owner instead of the "api" sentinel.
user = effective_user(request) user = effective_user(request)
uprefs = load_prefs_for_user(user) uprefs = load_prefs_for_user(user)
+30 -55
View File
@@ -273,65 +273,30 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
async def api_audit_memories(request: Request, session: str = Form(None)): async def api_audit_memories(request: Request, session: str = Form(None)):
"""Deduplicate and consolidate memories via LLM. """Deduplicate and consolidate memories via LLM.
Uses the default model from settings, or falls back to a session's model. Uses task/utility/default settings through the shared resolver, with
the active session as fallback when no task or utility model is set.
Returns before and after memory counts. Returns before and after memory counts.
""" """
from routes.model_routes import _load_settings, _normalize_base, build_chat_url
from core.database import ModelEndpoint
import json as _json
endpoint_url = model = None
headers = {}
# Try utility model from settings first — memory audit is a background
# task and should prefer the lighter utility model over the main chat model.
from src.task_endpoint import resolve_task_endpoint
user = _owner(request) user = _owner(request)
t_url, t_model, t_headers = resolve_task_endpoint(owner=user) fallback_url = fallback_model = None
if t_url and t_model: fallback_headers = None
endpoint_url, model, headers = t_url, t_model, t_headers if session:
else:
# Fall back to default model if no task/utility model configured
settings = _load_settings()
ep_id = settings.get("default_endpoint_id", "")
default_model = settings.get("default_model", "")
if ep_id:
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(
ModelEndpoint.id == ep_id, ModelEndpoint.is_enabled == True
).first()
if ep:
base = _normalize_base(ep.base_url)
endpoint_url = build_chat_url(base)
model = default_model
if not model and ep.models:
try:
models = _json.loads(ep.models) if isinstance(ep.models, str) else ep.models
if models:
model = models[0]
except Exception:
pass
if ep.api_key:
headers = {"Authorization": f"Bearer {ep.api_key}"}
finally:
db.close()
# Fall back to session model if no default configured
if not endpoint_url and session:
try: try:
sess = session_manager.get_session(session) sess = session_manager.get_session(session)
_assert_session_owner(sess, _owner(request)) _assert_session_owner(sess, user)
endpoint_url = sess.endpoint_url fallback_url = sess.endpoint_url
model = sess.model fallback_model = sess.model
headers = sess.headers fallback_headers = sess.headers
except KeyError: except KeyError:
pass pass
endpoint_url, model, headers = resolve_task_endpoint(
fallback_url, fallback_model, fallback_headers, owner=user
)
if not endpoint_url or not model: if not endpoint_url or not model:
raise HTTPException(400, "No default model configured — set one in Settings") raise HTTPException(400, "No default model configured — set one in Settings")
user = _owner(request)
result = await audit_memories( result = await audit_memories(
memory_manager, memory_manager,
memory_vector, memory_vector,
@@ -369,18 +334,28 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
model = None model = None
headers = {} headers = {}
user = _owner(request)
if session: if session:
try: try:
sess = session_manager.get_session(session) sess = session_manager.get_session(session)
_assert_session_owner(sess, _owner(request)) _assert_session_owner(sess, user)
endpoint_url, model, headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, owner=_owner(request)
)
except KeyError: except KeyError:
logger.warning("Session %s not found, falling back to utility endpoint", session) sess = None
endpoint_url, model, headers = resolve_endpoint("utility", owner=_owner(request)) except HTTPException as exc:
if exc.status_code != 404:
raise
sess = None
if sess is None:
logger.warning("Session %s not found or inaccessible, falling back to utility endpoint", session)
endpoint_url, model, headers = resolve_endpoint("utility", owner=user)
else: else:
endpoint_url, model, headers = resolve_task_endpoint(owner=_owner(request)) endpoint_url, model, headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, owner=user
)
else:
endpoint_url, model, headers = resolve_task_endpoint(owner=user)
if not endpoint_url or not model: if not endpoint_url or not model:
raise HTTPException(400, "No LLM model configured. Set a default model in Settings.") raise HTTPException(400, "No LLM model configured. Set a default model in Settings.")
+177 -4
View File
@@ -1,10 +1,19 @@
import asyncio
from types import SimpleNamespace
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
import routes.chat_helpers as chat_helpers
from routes.chat_helpers import ( from routes.chat_helpers import (
_enforce_chat_privileges, _enforce_chat_privileges,
_session_is_research_spinoff,
auto_name_session,
build_chat_context,
clean_thinking_for_save, clean_thinking_for_save,
needs_auto_name, needs_auto_name,
PreprocessedMessage,
PresetInfo,
save_assistant_response, save_assistant_response,
) )
@@ -220,10 +229,6 @@ def test_save_assistant_response_preserves_actual_and_requested_model():
assert sess.history[-1].metadata["model"] == "actual-model" assert sess.history[-1].metadata["model"] == "actual-model"
from types import SimpleNamespace
from routes.chat_helpers import _session_is_research_spinoff
class _SpinMsg: class _SpinMsg:
def __init__(self, role, metadata=None): def __init__(self, role, metadata=None):
self.role = role self.role = role
@@ -238,6 +243,57 @@ def test_spinoff_detected_from_chatmessage_history():
assert _session_is_research_spinoff(sess) is True assert _session_is_research_spinoff(sess) is True
def test_auto_name_session_passes_session_fallback_to_task_resolver(monkeypatch):
import src.llm_core as llm_core
import src.task_endpoint as task_endpoint
resolver_calls = []
llm_calls = []
def fake_resolve_task_endpoint(
fallback_url=None,
fallback_model=None,
fallback_headers=None,
owner=None,
):
resolver_calls.append((fallback_url, fallback_model, fallback_headers, owner))
return fallback_url, fallback_model, fallback_headers
async def fake_llm_call(url, model, messages, **kwargs):
llm_calls.append((url, model, messages, kwargs))
return "Focused Fix"
monkeypatch.setattr(task_endpoint, "resolve_task_endpoint", fake_resolve_task_endpoint)
monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call)
session_headers = {"Authorization": "Bearer session"}
sess = SimpleNamespace(
id="session-1",
owner="alice",
endpoint_url="http://session.example/v1/chat/completions",
model="session-model",
headers=session_headers,
history=[SimpleNamespace(role="user", content="Please fix the endpoint fallback bug.")],
)
updates = []
session_manager = SimpleNamespace(
update_session_name=lambda session_id, title: updates.append((session_id, title))
)
asyncio.run(auto_name_session(session_manager, sess))
assert resolver_calls == [(
"http://session.example/v1/chat/completions",
"session-model",
session_headers,
"alice",
)]
assert llm_calls[0][0] == "http://session.example/v1/chat/completions"
assert llm_calls[0][1] == "session-model"
assert llm_calls[0][3]["headers"] == session_headers
assert updates == [("session-1", "Focused Fix")]
def test_spinoff_detected_from_dict_history(): def test_spinoff_detected_from_dict_history():
sess = SimpleNamespace(history=[ sess = SimpleNamespace(history=[
{"role": "system", "metadata": {"research_spinoff_from": "rp-2"}}, {"role": "system", "metadata": {"research_spinoff_from": "rp-2"}},
@@ -262,3 +318,120 @@ def test_metadata_on_non_system_message_ignored():
def test_empty_or_missing_history(): def test_empty_or_missing_history():
assert _session_is_research_spinoff(SimpleNamespace(history=[])) is False assert _session_is_research_spinoff(SimpleNamespace(history=[])) is False
assert _session_is_research_spinoff(SimpleNamespace()) is False assert _session_is_research_spinoff(SimpleNamespace()) is False
async def _build_context_owner_probe(monkeypatch, request_state):
captured = {
"prefs_owner": None,
"preface_owner": None,
"compact_owner": None,
}
async def fake_preprocess(chat_handler, message, att_ids, sess, **kwargs):
return PreprocessedMessage(
enhanced_message=message,
user_content=message,
text_for_context=message,
youtube_transcripts=[],
attachment_meta=[],
)
def fake_extract_preset(chat_handler, preset_id):
return PresetInfo(
temperature=0.7,
max_tokens=1024,
system_prompt=None,
character_name=None,
)
def fake_add_user_message(sess, chat_handler, preprocessed, incognito=False):
sess.messages.append({"role": "user", "content": preprocessed.user_content})
def fake_load_prefs(owner):
captured["prefs_owner"] = owner
return {"memory_enabled": True, "skills_enabled": True}
def fake_build_context_preface(**kwargs):
captured["preface_owner"] = kwargs["owner"]
return [], [], []
async def fake_maybe_compact(sess, endpoint_url, model, messages, headers, owner=None):
captured["compact_owner"] = owner
return messages, 8192, False
monkeypatch.setattr(chat_helpers, "preprocess", fake_preprocess)
monkeypatch.setattr(chat_helpers, "extract_preset", fake_extract_preset)
monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message)
monkeypatch.setattr(chat_helpers, "load_prefs_for_user", fake_load_prefs)
monkeypatch.setattr(chat_helpers, "_normalize_model_id_from_cache", lambda sess: None)
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None)
monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact)
monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages)
import src.user_time as user_time
monkeypatch.setattr(
user_time,
"current_datetime_context_message",
lambda now_utc=None: {"role": "user", "content": "[Context - current date/time]"},
raising=False,
)
sess = SimpleNamespace(
endpoint_url="http://model.local/v1/chat/completions",
model="test-model",
headers={},
history=[],
messages=[],
)
sess.get_context_messages = lambda: list(sess.messages)
request = SimpleNamespace(state=SimpleNamespace(**request_state))
ctx = await build_chat_context(
sess=sess,
request=request,
chat_handler=SimpleNamespace(),
chat_processor=SimpleNamespace(build_context_preface=fake_build_context_preface),
message="hello",
session_id="session-1",
incognito=True,
)
return ctx, captured
@pytest.mark.asyncio
async def test_build_chat_context_uses_api_token_owner_for_compaction_scope(monkeypatch):
ctx, captured = await _build_context_owner_probe(
monkeypatch,
{
"api_token": True,
"api_token_owner": "alice",
"current_user": "api",
},
)
assert ctx.user == "alice"
assert captured == {
"prefs_owner": "alice",
"preface_owner": "alice",
"compact_owner": "alice",
}
@pytest.mark.asyncio
async def test_build_chat_context_keeps_cookie_user_owner_scope(monkeypatch):
ctx, captured = await _build_context_owner_probe(
monkeypatch,
{
"api_token": False,
"current_user": "bob",
},
)
assert ctx.user == "bob"
assert captured == {
"prefs_owner": "bob",
"preface_owner": "bob",
"compact_owner": "bob",
}
+163 -1
View File
@@ -7,11 +7,14 @@ another tenant's session and leak their chat history, session-scoped LLM
credentials, or session title. credentials, or session title.
""" """
import asyncio import asyncio
import io
import sys
import types
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException, UploadFile
import routes.memory_routes as mr import routes.memory_routes as mr
from src.request_models import MemoryAddRequest from src.request_models import MemoryAddRequest
@@ -46,6 +49,17 @@ def _request(user):
) )
def _upload(name="memories.json"):
return UploadFile(
filename=name,
file=io.BytesIO(b'[{"text": "Project Phoenix uses Python", "category": "project"}]'),
)
def _allow_memory_management(monkeypatch):
monkeypatch.setattr("src.auth_helpers.require_privilege", lambda request, privilege: "alice")
def test_extract_rejects_other_users_session(monkeypatch): def test_extract_rejects_other_users_session(monkeypatch):
router = _router(monkeypatch, caller="bob") router = _router(monkeypatch, caller="bob")
extract = _route(router, "/api/memory/extract", "POST") extract = _route(router, "/api/memory/extract", "POST")
@@ -69,6 +83,78 @@ def test_owner_can_access_own_session(monkeypatch):
assert out["session_name"] == "Secret project" assert out["session_name"] == "Secret project"
def test_audit_session_fallback_uses_resolver_without_manual_default(monkeypatch):
import src.task_endpoint as task_endpoint
memory_manager = MagicMock()
memory_vector = MagicMock()
session_headers = {"Authorization": "Bearer session"}
session_manager = MagicMock()
session_manager.get_session.return_value = SimpleNamespace(
owner="alice",
endpoint_url="http://session.example/v1/chat/completions",
model="session-model",
headers=session_headers,
)
router = mr.setup_memory_routes(memory_manager, session_manager, memory_vector)
audit_route = _route(router, "/api/memory/audit", "POST")
resolver_calls = []
audit_calls = []
def fake_resolve_task_endpoint(
fallback_url=None,
fallback_model=None,
fallback_headers=None,
owner=None,
):
resolver_calls.append((fallback_url, fallback_model, fallback_headers, owner))
if fallback_url and fallback_model:
return fallback_url, fallback_model, fallback_headers
return None, None, {}
async def fake_audit_memories(memory_manager_arg, memory_vector_arg, endpoint_url, model, headers, owner=None):
audit_calls.append((memory_manager_arg, memory_vector_arg, endpoint_url, model, headers, owner))
return {"before": 2, "after": 1}
fake_model_routes = types.ModuleType("routes.model_routes")
fake_model_routes._load_settings = lambda: {
"default_endpoint_id": "default",
"default_model": "default-model",
}
fake_model_routes._normalize_base = lambda base: base.rstrip("/")
fake_model_routes.build_chat_url = lambda base: f"{base}/chat/completions"
monkeypatch.setattr(mr, "resolve_task_endpoint", fake_resolve_task_endpoint)
monkeypatch.setattr(task_endpoint, "resolve_task_endpoint", fake_resolve_task_endpoint)
monkeypatch.setattr(mr, "audit_memories", fake_audit_memories)
monkeypatch.setitem(sys.modules, "routes.model_routes", fake_model_routes)
monkeypatch.setattr(
mr,
"SessionLocal",
lambda: (_ for _ in ()).throw(AssertionError("manual default branch should not run")),
)
out = asyncio.run(audit_route(request=_request("alice"), session="session-1"))
assert resolver_calls == [(
"http://session.example/v1/chat/completions",
"session-model",
session_headers,
"alice",
)]
assert audit_calls == [(
memory_manager,
memory_vector,
"http://session.example/v1/chat/completions",
"session-model",
session_headers,
"alice",
)]
assert out["ok"] is True
assert out["removed"] == 1
def test_add_memory_rejects_other_users_session(monkeypatch): def test_add_memory_rejects_other_users_session(monkeypatch):
memory_manager = MagicMock() memory_manager = MagicMock()
session_manager = MagicMock() session_manager = MagicMock()
@@ -125,3 +211,79 @@ def test_timeline_does_not_expose_other_users_session_name():
out = timeline(request=_request("alice")) out = timeline(request=_request("alice"))
assert out["timeline"][0]["session_name"] == "Unknown" assert out["timeline"][0]["session_name"] == "Unknown"
def test_import_missing_session_uses_utility_fallback(monkeypatch):
_allow_memory_management(monkeypatch)
memory_manager = MagicMock()
session_manager = MagicMock()
session_manager.get_session.side_effect = KeyError
resolve_endpoint = MagicMock(return_value=("http://utility", "utility-model", {}))
resolve_task_endpoint = MagicMock(side_effect=AssertionError("session task endpoint should not be used"))
monkeypatch.setattr(mr, "resolve_endpoint", resolve_endpoint)
monkeypatch.setattr(mr, "resolve_task_endpoint", resolve_task_endpoint)
router = mr.setup_memory_routes(memory_manager, session_manager)
import_memories = _route(router, "/api/memory/import", "POST")
out = asyncio.run(import_memories(request=_request("alice"), session="missing-session", file=_upload()))
assert out == {
"suggestions": [{"text": "Project Phoenix uses Python", "category": "project"}],
"filename": "memories.json",
}
session_manager.get_session.assert_called_once_with("missing-session")
resolve_endpoint.assert_called_once_with("utility", owner="alice")
def test_import_foreign_session_uses_same_utility_fallback(monkeypatch):
_allow_memory_management(monkeypatch)
memory_manager = MagicMock()
session_manager = MagicMock()
session_manager.get_session.return_value = SimpleNamespace(
owner="bob",
endpoint_url="http://bob-llm",
model="bob-model",
headers={"Authorization": "Bearer bob-secret"},
)
resolve_endpoint = MagicMock(return_value=("http://utility", "utility-model", {}))
resolve_task_endpoint = MagicMock(side_effect=AssertionError("foreign session endpoint should not be used"))
monkeypatch.setattr(mr, "resolve_endpoint", resolve_endpoint)
monkeypatch.setattr(mr, "resolve_task_endpoint", resolve_task_endpoint)
router = mr.setup_memory_routes(memory_manager, session_manager)
import_memories = _route(router, "/api/memory/import", "POST")
out = asyncio.run(import_memories(request=_request("alice"), session="bob-session", file=_upload()))
assert out["suggestions"] == [{"text": "Project Phoenix uses Python", "category": "project"}]
session_manager.get_session.assert_called_once_with("bob-session")
resolve_endpoint.assert_called_once_with("utility", owner="alice")
def test_import_owned_session_uses_session_endpoint(monkeypatch):
_allow_memory_management(monkeypatch)
memory_manager = MagicMock()
session_manager = MagicMock()
session_manager.get_session.return_value = SimpleNamespace(
owner="alice",
endpoint_url="http://alice-llm",
model="alice-model",
headers={"X-Session": "alice"},
)
resolve_endpoint = MagicMock(side_effect=AssertionError("utility fallback should not be used"))
resolve_task_endpoint = MagicMock(return_value=("http://alice-task", "alice-task-model", {"X-Task": "alice"}))
monkeypatch.setattr(mr, "resolve_endpoint", resolve_endpoint)
monkeypatch.setattr(mr, "resolve_task_endpoint", resolve_task_endpoint)
router = mr.setup_memory_routes(memory_manager, session_manager)
import_memories = _route(router, "/api/memory/import", "POST")
out = asyncio.run(import_memories(request=_request("alice"), session="alice-session", file=_upload()))
assert out["suggestions"] == [{"text": "Project Phoenix uses Python", "category": "project"}]
session_manager.get_session.assert_called_once_with("alice-session")
resolve_task_endpoint.assert_called_once_with(
"http://alice-llm",
"alice-model",
{"X-Session": "alice"},
owner="alice",
)
resolve_endpoint.assert_not_called()
+20
View File
@@ -147,6 +147,26 @@ def test_returns_explicit_fallback_when_no_endpoint_id_configured(monkeypatch):
) == fallback ) == fallback
def test_task_session_fallback_wins_before_default_when_task_and_utility_unset(monkeypatch):
settings = {
"task_endpoint_id": "",
"task_model": "",
"utility_endpoint_id": "",
"utility_model": "",
"default_endpoint_id": "default",
"default_model": "default-chat",
}
fallback = ("https://session.example/chat", "session-chat", {"X-Test": "session"})
_install_resolver_fakes(monkeypatch, settings, [_endpoint("default", "default-chat")])
assert resolve_endpoint(
"task",
fallback_url=fallback[0],
fallback_model=fallback[1],
fallback_headers=fallback[2],
) == fallback
def test_hidden_configured_model_selects_first_enabled_chat_model(monkeypatch): def test_hidden_configured_model_selects_first_enabled_chat_model(monkeypatch):
settings = { settings = {
"default_endpoint_id": "default", "default_endpoint_id": "default",