fix(memory): reject ambiguous multi-object outputs during skill extraction (#3985)

This commit is contained in:
Vishnu
2026-06-15 16:14:43 +05:30
committed by GitHub
parent 8fe98cf471
commit 933ec8fec9
3 changed files with 87 additions and 26 deletions
+42 -26
View File
@@ -66,41 +66,57 @@ def _has_duplicate_title(skills, title: str) -> bool:
def _extract_json_object(text: str) -> Optional[dict]:
"""Best-effort extraction of a JSON object from an LLM response.
The response may be wrapped in code fences or surrounded by prose, and some
models emit a stray brace in the prose before the real object
(e.g. "uses {placeholder} then {...}"). Slicing first-'{' .. last-'}' then
grabs an unparseable span and the skill is silently lost. Try the whole
string first, then each '{' start position in turn, returning the first
candidate that parses to a JSON object (dict). Returns None if none do.
The response may be wrapped in code fences or surrounded by prose. Uses
json.JSONDecoder().raw_decode() to locate the boundaries of complete JSON
objects starting at each '{' position. Nested objects are filtered out to
keep only top-level candidates. If multiple non-overlapping valid JSON
objects are found, it is treated as ambiguous and returns None. Otherwise,
returns the single valid candidate dictionary.
"""
if not text:
return None
s = text.strip()
if s.startswith("```"):
s = s.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
end = s.rfind("}")
if end == -1:
decoder = json.JSONDecoder()
candidates = []
start = s.find("{")
while start != -1:
try:
obj, idx = decoder.raw_decode(s[start:])
end_pos = start + idx
if isinstance(obj, dict):
candidates.append((start, end_pos, obj))
except (json.JSONDecodeError, ValueError):
pass
start = s.find("{", start + 1)
# Filter out nested candidates to identify top-level dictionaries
top_level = []
for c in candidates:
is_nested = False
for other in candidates:
if other == c:
continue
if other[0] <= c[0] and c[1] <= other[1]:
is_nested = True
break
if not is_nested:
top_level.append(c)
if not top_level:
return None
def _as_dict(candidate):
try:
obj = json.loads(candidate)
except (json.JSONDecodeError, ValueError):
return None
return obj if isinstance(obj, dict) else None
if len(top_level) > 1:
logger.debug(
"[skill-extract] Found multiple non-overlapping JSON objects: %s",
[item[2].get("title") for item in top_level]
)
return None
# The clean, common case: the whole (de-fenced) string is the object.
obj = _as_dict(s)
if obj is not None:
return obj
# Otherwise scan each '{' candidate up to the last '}'.
start = s.find("{")
while 0 <= start < end:
obj = _as_dict(s[start : end + 1])
if obj is not None:
return obj
start = s.find("{", start + 1)
return None
return top_level[0][2]
async def maybe_extract_skill(
+15
View File
@@ -41,3 +41,18 @@ def test_non_object_json_returns_none():
def test_empty_input_returns_none():
assert skill_extractor._extract_json_object("") is None
def test_multiple_objects_returns_none():
# Two complete valid non-overlapping JSON objects should return None (fail closed).
resp = '{"title": "Restart", "steps": []} and {"title": "Stop", "steps": []}'
assert skill_extractor._extract_json_object(resp) is None
def test_trailing_stray_brace_is_recovered():
# A single valid JSON object followed by trailing text containing a stray brace should be recovered.
resp = '{"title": "Restart the service", "steps": ["a"]} }'
data = skill_extractor._extract_json_object(resp)
assert isinstance(data, dict)
assert data["title"] == "Restart the service"
+30
View File
@@ -115,3 +115,33 @@ async def test_maybe_extract_skill_drops_when_no_candidate_parses(monkeypatch):
assert entry is None
assert not skills_manager.added
async def test_maybe_extract_skill_drops_on_multiple_json_objects(monkeypatch):
# Two valid JSON objects should be rejected by maybe_extract_skill.
resp = (
'{"title": "Deploy runbook", "problem": "manual", "solution": "script", '
'"steps": ["build"], "tags": ["deploy"], "confidence": 0.9}\n'
'{"title": "Unrelated skill", "problem": "manual", "solution": "script", '
'"steps": ["build"], "tags": ["deploy"], "confidence": 0.9}'
)
async def fake_llm_call_async(*args, **kwargs):
return resp
monkeypatch.setattr("src.llm_core.llm_call_async", fake_llm_call_async)
skills_manager = _FakeSkillsManager()
entry = await skill_extractor.maybe_extract_skill(
_FakeSession(),
skills_manager,
endpoint_url="http://endpoint",
model="test-model",
headers={},
round_count=3,
tool_count=3,
owner="alice",
)
assert entry is None
assert not skills_manager.added