fix(routes): normalize session owner fallback helpers (#4313)

* fix(memory): normalize import session fallback

* fix(chat): use token owner for compaction scope

* fix(background): honor session endpoint fallback
This commit is contained in:
RaresKeY
2026-06-16 08:07:42 +03:00
committed by GitHub
parent d795d9a923
commit 2b519bf355
5 changed files with 398 additions and 75 deletions
+5 -12
View File
@@ -159,17 +159,9 @@ async def auto_name_session(session_manager, sess):
return
owner = getattr(sess, "owner", None)
t_url, t_model, t_headers = resolve_task_endpoint(owner=owner)
if not t_model:
# If no task/utility model is configured at all, fall back to
# the session's own model so auto-naming still works even on
# minimal setups.
from src.endpoint_resolver import resolve_endpoint
_fallback = resolve_endpoint("default", owner=owner)
if _fallback and _fallback[1]:
t_url, t_model, t_headers = _fallback
else:
t_url, t_model, t_headers = sess.endpoint_url, sess.model, sess.headers
t_url, t_model, t_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, owner=owner
)
if not t_model:
logger.debug("[auto-name] No model provided, skipping")
return
@@ -576,7 +568,8 @@ async def build_chat_context(
if not incognito:
fire_message_event(request, webhook_manager, session_id, sess, message, compare_mode)
# Resolve user prefs
# Resolve owner-scoped prefs/context. Browser requests keep the cookie user;
# bearer-token chat requests use the token owner instead of the "api" sentinel.
user = effective_user(request)
uprefs = load_prefs_for_user(user)
+33 -58
View File
@@ -273,65 +273,30 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
async def api_audit_memories(request: Request, session: str = Form(None)):
"""Deduplicate and consolidate memories via LLM.
Uses the default model from settings, or falls back to a session's model.
Uses task/utility/default settings through the shared resolver, with
the active session as fallback when no task or utility model is set.
Returns before and after memory counts.
"""
from routes.model_routes import _load_settings, _normalize_base, build_chat_url
from core.database import ModelEndpoint
import json as _json
endpoint_url = model = None
headers = {}
# Try utility model from settings first — memory audit is a background
# task and should prefer the lighter utility model over the main chat model.
from src.task_endpoint import resolve_task_endpoint
user = _owner(request)
t_url, t_model, t_headers = resolve_task_endpoint(owner=user)
if t_url and t_model:
endpoint_url, model, headers = t_url, t_model, t_headers
else:
# Fall back to default model if no task/utility model configured
settings = _load_settings()
ep_id = settings.get("default_endpoint_id", "")
default_model = settings.get("default_model", "")
if ep_id:
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(
ModelEndpoint.id == ep_id, ModelEndpoint.is_enabled == True
).first()
if ep:
base = _normalize_base(ep.base_url)
endpoint_url = build_chat_url(base)
model = default_model
if not model and ep.models:
try:
models = _json.loads(ep.models) if isinstance(ep.models, str) else ep.models
if models:
model = models[0]
except Exception:
pass
if ep.api_key:
headers = {"Authorization": f"Bearer {ep.api_key}"}
finally:
db.close()
fallback_url = fallback_model = None
fallback_headers = None
if session:
try:
sess = session_manager.get_session(session)
_assert_session_owner(sess, user)
fallback_url = sess.endpoint_url
fallback_model = sess.model
fallback_headers = sess.headers
except KeyError:
pass
# Fall back to session model if no default configured
if not endpoint_url and session:
try:
sess = session_manager.get_session(session)
_assert_session_owner(sess, _owner(request))
endpoint_url = sess.endpoint_url
model = sess.model
headers = sess.headers
except KeyError:
pass
endpoint_url, model, headers = resolve_task_endpoint(
fallback_url, fallback_model, fallback_headers, owner=user
)
if not endpoint_url or not model:
raise HTTPException(400, "No default model configured — set one in Settings")
user = _owner(request)
result = await audit_memories(
memory_manager,
memory_vector,
@@ -369,18 +334,28 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
model = None
headers = {}
user = _owner(request)
if session:
try:
sess = session_manager.get_session(session)
_assert_session_owner(sess, _owner(request))
endpoint_url, model, headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, owner=_owner(request)
)
_assert_session_owner(sess, user)
except KeyError:
logger.warning("Session %s not found, falling back to utility endpoint", session)
endpoint_url, model, headers = resolve_endpoint("utility", owner=_owner(request))
sess = None
except HTTPException as exc:
if exc.status_code != 404:
raise
sess = None
if sess is None:
logger.warning("Session %s not found or inaccessible, falling back to utility endpoint", session)
endpoint_url, model, headers = resolve_endpoint("utility", owner=user)
else:
endpoint_url, model, headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, owner=user
)
else:
endpoint_url, model, headers = resolve_task_endpoint(owner=_owner(request))
endpoint_url, model, headers = resolve_task_endpoint(owner=user)
if not endpoint_url or not model:
raise HTTPException(400, "No LLM model configured. Set a default model in Settings.")
+177 -4
View File
@@ -1,10 +1,19 @@
import asyncio
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
import routes.chat_helpers as chat_helpers
from routes.chat_helpers import (
_enforce_chat_privileges,
_session_is_research_spinoff,
auto_name_session,
build_chat_context,
clean_thinking_for_save,
needs_auto_name,
PreprocessedMessage,
PresetInfo,
save_assistant_response,
)
@@ -220,10 +229,6 @@ def test_save_assistant_response_preserves_actual_and_requested_model():
assert sess.history[-1].metadata["model"] == "actual-model"
from types import SimpleNamespace
from routes.chat_helpers import _session_is_research_spinoff
class _SpinMsg:
def __init__(self, role, metadata=None):
self.role = role
@@ -238,6 +243,57 @@ def test_spinoff_detected_from_chatmessage_history():
assert _session_is_research_spinoff(sess) is True
def test_auto_name_session_passes_session_fallback_to_task_resolver(monkeypatch):
    """Auto-naming hands the session's own endpoint to the task resolver.

    The resolver stub echoes its fallback arguments, so the LLM stub must be
    called with the session URL, model, and headers, and the returned title
    must be stored through ``update_session_name``.
    """
    import src.llm_core as llm_core
    import src.task_endpoint as task_endpoint

    session_url = "http://session.example/v1/chat/completions"
    session_headers = {"Authorization": "Bearer session"}
    resolver_calls = []
    llm_calls = []
    updates = []

    def fake_resolve_task_endpoint(
        fallback_url=None,
        fallback_model=None,
        fallback_headers=None,
        owner=None,
    ):
        # Record the arguments, then behave as if no task/utility model exists.
        resolver_calls.append((fallback_url, fallback_model, fallback_headers, owner))
        return fallback_url, fallback_model, fallback_headers

    async def fake_llm_call(url, model, messages, **kwargs):
        llm_calls.append((url, model, messages, kwargs))
        return "Focused Fix"

    def record_rename(session_id, title):
        updates.append((session_id, title))

    monkeypatch.setattr(task_endpoint, "resolve_task_endpoint", fake_resolve_task_endpoint)
    monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call)

    first_message = SimpleNamespace(
        role="user",
        content="Please fix the endpoint fallback bug.",
    )
    sess = SimpleNamespace(
        id="session-1",
        owner="alice",
        endpoint_url=session_url,
        model="session-model",
        headers=session_headers,
        history=[first_message],
    )
    session_manager = SimpleNamespace(update_session_name=record_rename)

    asyncio.run(auto_name_session(session_manager, sess))

    assert resolver_calls == [(session_url, "session-model", session_headers, "alice")]
    used_url, used_model, _messages, call_kwargs = llm_calls[0]
    assert used_url == session_url
    assert used_model == "session-model"
    assert call_kwargs["headers"] == session_headers
    assert updates == [("session-1", "Focused Fix")]
def test_spinoff_detected_from_dict_history():
sess = SimpleNamespace(history=[
{"role": "system", "metadata": {"research_spinoff_from": "rp-2"}},
@@ -262,3 +318,120 @@ def test_metadata_on_non_system_message_ignored():
def test_empty_or_missing_history():
    """A session with an empty or absent ``history`` is not a research spin-off."""
    for sess in (SimpleNamespace(history=[]), SimpleNamespace()):
        assert _session_is_research_spinoff(sess) is False
async def _build_context_owner_probe(monkeypatch, request_state):
    """Run ``build_chat_context`` with stubbed collaborators and record owners.

    ``request_state`` supplies the attributes placed on ``request.state``.
    Returns ``(ctx, captured)``, where ``captured`` holds the owner value seen
    by the prefs loader, the context-preface builder, and the compaction stub.
    """
    captured = dict.fromkeys(("prefs_owner", "preface_owner", "compact_owner"))

    async def fake_preprocess(chat_handler, message, att_ids, sess, **kwargs):
        # Pass the message through untouched, with no attachments or transcripts.
        return PreprocessedMessage(
            enhanced_message=message,
            user_content=message,
            text_for_context=message,
            youtube_transcripts=[],
            attachment_meta=[],
        )

    def fake_extract_preset(chat_handler, preset_id):
        return PresetInfo(
            temperature=0.7,
            max_tokens=1024,
            system_prompt=None,
            character_name=None,
        )

    def fake_add_user_message(sess, chat_handler, preprocessed, incognito=False):
        sess.messages.append({"role": "user", "content": preprocessed.user_content})

    def fake_load_prefs(owner):
        captured["prefs_owner"] = owner
        return {"memory_enabled": True, "skills_enabled": True}

    def fake_build_context_preface(**kwargs):
        captured["preface_owner"] = kwargs["owner"]
        return [], [], []

    async def fake_maybe_compact(sess, endpoint_url, model, messages, headers, owner=None):
        captured["compact_owner"] = owner
        return messages, 8192, False

    # Applied in this order on the chat_helpers module.
    helper_stubs = {
        "preprocess": fake_preprocess,
        "extract_preset": fake_extract_preset,
        "add_user_message": fake_add_user_message,
        "load_prefs_for_user": fake_load_prefs,
        "_normalize_model_id_from_cache": lambda sess: None,
        "normalize_model_id": lambda endpoint_url, model, **kwargs: None,
        "maybe_compact": fake_maybe_compact,
        "trim_for_context": lambda messages, context_length: messages,
    }
    for attr_name, stub in helper_stubs.items():
        monkeypatch.setattr(chat_helpers, attr_name, stub)

    import src.user_time as user_time

    # raising=False: the patch is applied even if the attribute does not exist.
    monkeypatch.setattr(
        user_time,
        "current_datetime_context_message",
        lambda now_utc=None: {"role": "user", "content": "[Context - current date/time]"},
        raising=False,
    )

    sess = SimpleNamespace(
        endpoint_url="http://model.local/v1/chat/completions",
        model="test-model",
        headers={},
        history=[],
        messages=[],
    )
    sess.get_context_messages = lambda: list(sess.messages)

    state = SimpleNamespace(**request_state)
    processor = SimpleNamespace(build_context_preface=fake_build_context_preface)
    ctx = await build_chat_context(
        sess=sess,
        request=SimpleNamespace(state=state),
        chat_handler=SimpleNamespace(),
        chat_processor=processor,
        message="hello",
        session_id="session-1",
        incognito=True,
    )
    return ctx, captured
@pytest.mark.asyncio
async def test_build_chat_context_uses_api_token_owner_for_compaction_scope(monkeypatch):
    """Bearer-token requests scope prefs, preface, and compaction to the token owner."""
    token_state = {
        "api_token": True,
        "api_token_owner": "alice",
        "current_user": "api",
    }
    ctx, captured = await _build_context_owner_probe(monkeypatch, token_state)

    assert ctx.user == "alice"
    expected_owners = dict.fromkeys(
        ("prefs_owner", "preface_owner", "compact_owner"), "alice"
    )
    assert captured == expected_owners
@pytest.mark.asyncio
async def test_build_chat_context_keeps_cookie_user_owner_scope(monkeypatch):
    """Non-token requests keep ``current_user`` as the owner for every scope."""
    cookie_state = {
        "api_token": False,
        "current_user": "bob",
    }
    ctx, captured = await _build_context_owner_probe(monkeypatch, cookie_state)

    assert ctx.user == "bob"
    expected_owners = dict.fromkeys(
        ("prefs_owner", "preface_owner", "compact_owner"), "bob"
    )
    assert captured == expected_owners
+163 -1
View File
@@ -7,11 +7,14 @@ another tenant's session and leak their chat history, session-scoped LLM
credentials, or session title.
"""
import asyncio
import io
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from fastapi import HTTPException, UploadFile
import routes.memory_routes as mr
from src.request_models import MemoryAddRequest
@@ -46,6 +49,17 @@ def _request(user):
)
def _upload(name="memories.json"):
    """Return an in-memory ``UploadFile`` named ``name`` holding one JSON memory entry."""
    payload = b'[{"text": "Project Phoenix uses Python", "category": "project"}]'
    return UploadFile(filename=name, file=io.BytesIO(payload))
def _allow_memory_management(monkeypatch):
    """Stub ``src.auth_helpers.require_privilege`` so it always returns ``"alice"``."""

    def grant_as_alice(request, privilege):
        return "alice"

    monkeypatch.setattr("src.auth_helpers.require_privilege", grant_as_alice)
def test_extract_rejects_other_users_session(monkeypatch):
router = _router(monkeypatch, caller="bob")
extract = _route(router, "/api/memory/extract", "POST")
@@ -69,6 +83,78 @@ def test_owner_can_access_own_session(monkeypatch):
assert out["session_name"] == "Secret project"
def test_audit_session_fallback_uses_resolver_without_manual_default(monkeypatch):
    """The audit route forwards the session endpoint to the shared resolver.

    The resolver stub returns the fallback only when both URL and model are
    given. ``SessionLocal`` is replaced with a callable that raises, so the test
    fails if the route performs a manual default-endpoint lookup.
    """
    import src.task_endpoint as task_endpoint

    session_url = "http://session.example/v1/chat/completions"
    session_headers = {"Authorization": "Bearer session"}

    memory_manager = MagicMock()
    memory_vector = MagicMock()
    session_manager = MagicMock()
    session_manager.get_session.return_value = SimpleNamespace(
        owner="alice",
        endpoint_url=session_url,
        model="session-model",
        headers=session_headers,
    )
    router = mr.setup_memory_routes(memory_manager, session_manager, memory_vector)
    audit_route = _route(router, "/api/memory/audit", "POST")

    resolver_calls = []
    audit_calls = []

    def fake_resolve_task_endpoint(
        fallback_url=None,
        fallback_model=None,
        fallback_headers=None,
        owner=None,
    ):
        resolver_calls.append((fallback_url, fallback_model, fallback_headers, owner))
        if not (fallback_url and fallback_model):
            return None, None, {}
        return fallback_url, fallback_model, fallback_headers

    async def fake_audit_memories(memory_manager_arg, memory_vector_arg, endpoint_url, model, headers, owner=None):
        audit_calls.append((memory_manager_arg, memory_vector_arg, endpoint_url, model, headers, owner))
        return {"before": 2, "after": 1}

    def forbid_manual_default():
        raise AssertionError("manual default branch should not run")

    # Stand-in for routes.model_routes exposing a configured default model.
    fake_model_routes = types.ModuleType("routes.model_routes")
    fake_model_routes._load_settings = lambda: {
        "default_endpoint_id": "default",
        "default_model": "default-model",
    }
    fake_model_routes._normalize_base = lambda base: base.rstrip("/")
    fake_model_routes.build_chat_url = lambda base: f"{base}/chat/completions"

    monkeypatch.setattr(mr, "resolve_task_endpoint", fake_resolve_task_endpoint)
    monkeypatch.setattr(task_endpoint, "resolve_task_endpoint", fake_resolve_task_endpoint)
    monkeypatch.setattr(mr, "audit_memories", fake_audit_memories)
    monkeypatch.setitem(sys.modules, "routes.model_routes", fake_model_routes)
    monkeypatch.setattr(mr, "SessionLocal", forbid_manual_default)

    out = asyncio.run(audit_route(request=_request("alice"), session="session-1"))

    session_endpoint = (session_url, "session-model", session_headers)
    assert resolver_calls == [session_endpoint + ("alice",)]
    assert audit_calls == [(memory_manager, memory_vector) + session_endpoint + ("alice",)]
    assert out["ok"] is True
    assert out["removed"] == 1
def test_add_memory_rejects_other_users_session(monkeypatch):
memory_manager = MagicMock()
session_manager = MagicMock()
@@ -125,3 +211,79 @@ def test_timeline_does_not_expose_other_users_session_name():
out = timeline(request=_request("alice"))
assert out["timeline"][0]["session_name"] == "Unknown"
def test_import_missing_session_uses_utility_fallback(monkeypatch):
    """An unknown session id makes the import route use the utility endpoint."""
    _allow_memory_management(monkeypatch)

    utility_resolver = MagicMock(return_value=("http://utility", "utility-model", {}))
    task_resolver = MagicMock(
        side_effect=AssertionError("session task endpoint should not be used")
    )
    monkeypatch.setattr(mr, "resolve_endpoint", utility_resolver)
    monkeypatch.setattr(mr, "resolve_task_endpoint", task_resolver)

    memory_manager = MagicMock()
    session_manager = MagicMock()
    # The session lookup fails as it would for an id that does not exist.
    session_manager.get_session.side_effect = KeyError

    router = mr.setup_memory_routes(memory_manager, session_manager)
    import_memories = _route(router, "/api/memory/import", "POST")
    out = asyncio.run(
        import_memories(request=_request("alice"), session="missing-session", file=_upload())
    )

    expected = {
        "suggestions": [{"text": "Project Phoenix uses Python", "category": "project"}],
        "filename": "memories.json",
    }
    assert out == expected
    session_manager.get_session.assert_called_once_with("missing-session")
    utility_resolver.assert_called_once_with("utility", owner="alice")
def test_import_foreign_session_uses_same_utility_fallback(monkeypatch):
    """A session owned by another user must not lend its endpoint to the caller.

    The task resolver stub raises if invoked, so the test passes only when the
    route resolves the utility endpoint for the caller ("alice") instead.
    """
    _allow_memory_management(monkeypatch)

    utility_resolver = MagicMock(return_value=("http://utility", "utility-model", {}))
    task_resolver = MagicMock(
        side_effect=AssertionError("foreign session endpoint should not be used")
    )
    monkeypatch.setattr(mr, "resolve_endpoint", utility_resolver)
    monkeypatch.setattr(mr, "resolve_task_endpoint", task_resolver)

    bobs_session = SimpleNamespace(
        owner="bob",
        endpoint_url="http://bob-llm",
        model="bob-model",
        headers={"Authorization": "Bearer bob-secret"},
    )
    memory_manager = MagicMock()
    session_manager = MagicMock()
    session_manager.get_session.return_value = bobs_session

    router = mr.setup_memory_routes(memory_manager, session_manager)
    import_memories = _route(router, "/api/memory/import", "POST")
    out = asyncio.run(
        import_memories(request=_request("alice"), session="bob-session", file=_upload())
    )

    assert out["suggestions"] == [{"text": "Project Phoenix uses Python", "category": "project"}]
    session_manager.get_session.assert_called_once_with("bob-session")
    utility_resolver.assert_called_once_with("utility", owner="alice")
def test_import_owned_session_uses_session_endpoint(monkeypatch):
    """The caller's own session is passed to the task resolver as the fallback.

    The utility resolver stub raises if invoked, so it must never be called.
    """
    _allow_memory_management(monkeypatch)

    utility_resolver = MagicMock(
        side_effect=AssertionError("utility fallback should not be used")
    )
    task_resolver = MagicMock(
        return_value=("http://alice-task", "alice-task-model", {"X-Task": "alice"})
    )
    monkeypatch.setattr(mr, "resolve_endpoint", utility_resolver)
    monkeypatch.setattr(mr, "resolve_task_endpoint", task_resolver)

    alices_session = SimpleNamespace(
        owner="alice",
        endpoint_url="http://alice-llm",
        model="alice-model",
        headers={"X-Session": "alice"},
    )
    memory_manager = MagicMock()
    session_manager = MagicMock()
    session_manager.get_session.return_value = alices_session

    router = mr.setup_memory_routes(memory_manager, session_manager)
    import_memories = _route(router, "/api/memory/import", "POST")
    out = asyncio.run(
        import_memories(request=_request("alice"), session="alice-session", file=_upload())
    )

    assert out["suggestions"] == [{"text": "Project Phoenix uses Python", "category": "project"}]
    session_manager.get_session.assert_called_once_with("alice-session")
    task_resolver.assert_called_once_with(
        "http://alice-llm",
        "alice-model",
        {"X-Session": "alice"},
        owner="alice",
    )
    utility_resolver.assert_not_called()
+20
View File
@@ -147,6 +147,26 @@ def test_returns_explicit_fallback_when_no_endpoint_id_configured(monkeypatch):
) == fallback
def test_task_session_fallback_wins_before_default_when_task_and_utility_unset(monkeypatch):
    """With task and utility unset, the explicit fallback is returned for "task".

    A default endpoint and model are configured, yet the resolver is expected
    to return the caller-supplied fallback triple unchanged.
    """
    settings = {
        "task_endpoint_id": "",
        "task_model": "",
        "utility_endpoint_id": "",
        "utility_model": "",
        "default_endpoint_id": "default",
        "default_model": "default-chat",
    }
    fallback = ("https://session.example/chat", "session-chat", {"X-Test": "session"})
    _install_resolver_fakes(monkeypatch, settings, [_endpoint("default", "default-chat")])

    fb_url, fb_model, fb_headers = fallback
    resolved = resolve_endpoint(
        "task",
        fallback_url=fb_url,
        fallback_model=fb_model,
        fallback_headers=fb_headers,
    )
    assert resolved == fallback
def test_hidden_configured_model_selects_first_enabled_chat_model(monkeypatch):
settings = {
"default_endpoint_id": "default",