"""Unified-diff → symbol mapping and PR-style risk scoring (B4 / PR-B).
Uses the `unidiff` library for parsing. Only graph-resident symbols are mapped;
newly added Java members are not modelled — see `notes` on the returned report.
"""
from __future__ import annotations
import re
from dataclasses import asdict, dataclass
from typing import Any
from unidiff import PatchSet
from unidiff.errors import UnidiffParseError
from ladybug_queries import SymbolHit, find_symbols_in_file_range, _row_to_symbol
@dataclass
class DiffHunk:
    """One unified-diff hunk with both its old-file and new-file line ranges.

    `parse_unified_diff` fills the ranges from the hunk header, so they cover
    context lines as well as changed lines. Ranges are inclusive and 1-based;
    a side with no lines (pure addition or pure deletion) has start = end = 0.
    """
    target_path: str  # new-file path, git `a/` / `b/` prefix removed
    source_path: str  # old-file path, git `a/` / `b/` prefix removed
    target_line_start: int # inclusive, 1-based; 0 when the hunk has no new-file lines
    target_line_end: int # inclusive
    source_line_start: int  # inclusive, 1-based; 0 when the hunk has no old-file lines
    source_line_end: int  # inclusive
    source_length: int = 0  # old-file line count of the hunk (0 for a pure addition)
    target_length: int = 0  # new-file line count of the hunk (0 for a pure deletion)
@dataclass
class ChangedSymbol:
    """A graph `Symbol` overlapped by at least one diff hunk."""
    symbol_id: str  # `Symbol.id` in the graph
    fqn: str  # fully-qualified name; members use `Type#member`
    kind: str # 'method' | 'type' | 'field'
    change_type: str # 'added' | 'removed' | 'modified' (map_hunks_to_symbols emits only 'removed' / 'modified')
    file: str  # the graph's `Symbol.filename` for this symbol
    hunk_lines: list[int]  # sorted unique positive lines overlapping the symbol: old-file lines for 'removed', new-file lines for 'modified'
    cross_service_callers_count: int = 0  # set by compute_risk from the routes the symbol exposes; 0 before that
@dataclass
class PrRiskReport:
    """Aggregated risk result for one diff, as built by `compute_risk` / `analyze_pr_pipeline`."""
    changed_symbols: list[ChangedSymbol]  # changed symbols still found in the graph, with caller counts filled in
    blast_radius_total: int  # sum of the per-symbol impact_analysis result counts
    blast_radius_by_symbol: dict[str, int]  # symbol_id -> impact_analysis result count
    cross_service_callers: int  # caller edges whose endpoints are in different microservices
    routes_touched: list[str]  # distinct route ids exposed by changed symbols, first-seen order
    risk_score: float  # clamped to [0, 1]
    risk_band: str  # 'low' (< 0.3) | 'medium' (< 0.7) | 'high'
    notes: list[str]  # human-readable caveats (binary/rename skips, path ambiguity, heuristics)
_BINARY_DIFF_LINE = re.compile(r"^Binary files .+ differ\s*$")
# Heuristic: new Java method/ctor-looking line. Covers annotations, method-level
# generics, `default` interface methods, and return types with spaces (e.g.
# `Map m(`). Misses multi-line signatures, some compact record
# forms, and unusual annotations; `_notes_for_unindexed_additions` is best-effort.
_DECL_ADD = re.compile(
r"^\+\s*"
r"(?:(?:@[\w.]+\([^)]*\))\s+)*"
r"(?:<[^>]+>\s+)?"
r"(?:(?:public|private|protected|default|static|final|synchronized|abstract|native)\s+)*"
r"(.+?)\s+(\w+)\s*\(",
)
def _strip_ab_prefix(path: str) -> str:
p = path.strip()
if p.startswith(("a/", "b/")):
return p[2:]
return p
def _hunk_ranges(h: Any) -> tuple[tuple[int, int], tuple[int, int]]:
"""Return ((src_start, src_end inclusive), (tgt_start, tgt_end inclusive))."""
src_len = int(getattr(h, "source_length", 0) or 0)
tgt_len = int(getattr(h, "target_length", 0) or 0)
src_start = int(getattr(h, "source_start", 0) or 0)
tgt_start = int(getattr(h, "target_start", 0) or 0)
if src_len <= 0:
src_start, src_end = 0, 0
else:
src_end = src_start + src_len - 1
if tgt_len <= 0:
tgt_start, tgt_end = 0, 0
else:
tgt_end = tgt_start + tgt_len - 1
return (src_start, src_end), (tgt_start, tgt_end)
def parse_unified_diff(diff_text: str) -> list[DiffHunk]:
    """Split a unified diff into `DiffHunk` records.

    Renamed files are skipped (they are reported by `collect_diff_file_notes`),
    as are entries whose path is empty. Blank/None input, or text that unidiff
    refuses to parse, yields an empty list instead of raising.
    """
    if not (diff_text or "").strip():
        return []
    try:
        patch_set = PatchSet(diff_text.splitlines(keepends=True))
    except UnidiffParseError:
        return []

    hunks: list[DiffHunk] = []
    for patched in patch_set:
        if getattr(patched, "is_rename", False):
            continue
        new_path = _strip_ab_prefix(str(patched.path or ""))
        if not new_path:
            continue
        old_path = _strip_ab_prefix(
            str(getattr(patched, "source_file", "") or patched.path or "")
        )
        for hunk in patched:
            (old_first, old_last), (new_first, new_last) = _hunk_ranges(hunk)
            hunks.append(
                DiffHunk(
                    target_path=new_path,
                    source_path=old_path,
                    target_line_start=new_first,
                    target_line_end=new_last,
                    source_line_start=old_first,
                    source_line_end=old_last,
                    source_length=int(getattr(hunk, "source_length", 0) or 0),
                    target_length=int(getattr(hunk, "target_length", 0) or 0),
                )
            )
    return hunks
def collect_diff_file_notes(diff_text: str) -> list[str]:
    """Return human-readable notes about diff entries that are not symbol-mapped.

    Emits one note per `Binary files ... differ` line and one per renamed file.
    If unidiff cannot parse the text, a parse-failure note is appended after the
    binary notes instead of raising. Blank/None input yields an empty list.
    """
    if not (diff_text or "").strip():
        return []

    notes = [
        f"skipped binary diff: {raw.strip()}"
        for raw in diff_text.splitlines()
        if _BINARY_DIFF_LINE.match(raw)
    ]

    try:
        patch_set = PatchSet(diff_text.splitlines(keepends=True))
    except UnidiffParseError:
        notes.append("diff text could not be fully parsed as a unified patch")
        return notes

    for patched in patch_set:
        if not getattr(patched, "is_rename", False):
            continue
        old_name = _strip_ab_prefix(str(getattr(patched, "source_file", "") or ""))
        new_name = _strip_ab_prefix(str(patched.path or ""))
        notes.append(f"rename (symbols not mapped): {old_name} -> {new_name}")
    return notes
def _resolve_graph_filename(
graph: Any,
path: str,
*,
ambiguity_notes: list[str] | None = None,
) -> str | None:
"""Map a diff path to `Symbol.filename` values stored in LadybugDB."""
variants = {_strip_ab_prefix(path)}
for v in list(variants):
if v.startswith("./"):
variants.add(v[2:])
for candidate in variants:
if not candidate:
continue
rows = graph._rows(
"MATCH (s:Symbol) WHERE s.filename = $fn RETURN s.filename AS fn LIMIT 1",
{"fn": candidate},
)
if rows and rows[0].get("fn"):
return str(rows[0]["fn"])
tail = path.strip().split("/")[-1]
if tail:
rows = graph._rows(
"MATCH (s:Symbol) WHERE s.filename ENDS WITH $tail "
"RETURN DISTINCT s.filename AS fn LIMIT 8",
{"tail": "/" + tail},
)
n = len(rows)
if n > 1 and ambiguity_notes is not None:
fns = [str(r.get("fn") or "") for r in rows if r.get("fn")]
ambiguity_notes.append(
f"ambiguous filename tail {tail!r} ({n} graph paths); "
f"ENDS WITH resolution skipped ({', '.join(fns[:4])}"
f"{'â¦' if len(fns) > 4 else ''})",
)
if n == 1 and rows[0].get("fn"):
return str(rows[0]["fn"])
return None
def _symbol_to_changed(
    sym: SymbolHit,
    *,
    change_type: str,
    lines: list[int],
) -> ChangedSymbol:
    """Build a `ChangedSymbol` from a graph hit.

    Graph kinds are bucketed: class, interface, enum, record and annotation
    become 'type'; 'field' stays 'field'; every other kind (constructors
    included) becomes 'method'. `lines` is converted with `int`, filtered to
    positive values, de-duplicated and sorted.
    """
    graph_kind = sym.kind
    type_kinds = ("class", "interface", "enum", "record", "annotation")
    if graph_kind in type_kinds:
        bucket = "type"
    elif graph_kind == "field":
        bucket = "field"
    else:
        bucket = "method"
    line_numbers = sorted({n for n in map(int, lines) if n > 0})
    return ChangedSymbol(
        symbol_id=sym.id,
        fqn=sym.fqn,
        kind=bucket,
        change_type=change_type,
        file=sym.filename,
        hunk_lines=line_numbers,
    )
def _decl_added_lines_for_file(diff_text: str, resolved_filename: str) -> int:
"""Count `+` lines in the diff that look like Java member declarations for one file."""
lines = diff_text.splitlines()
in_file = False
n = 0
for line in lines:
if line.startswith("+++ "):
rest = line[4:].strip()
if rest.startswith("b/"):
rest = rest[2:]
in_file = rest.endswith(resolved_filename) or resolved_filename.endswith(rest)
continue
if not in_file:
continue
if _DECL_ADD.match(line):
n += 1
return n
def _notes_for_unindexed_additions(
graph: Any,
diff_text: str,
changed: list[ChangedSymbol],
hunks: list[DiffHunk],
) -> list[str]:
"""Heuristic: added declaration lines vs indexed methods touched on the same file."""
notes: list[str] = []
if not diff_text.strip():
return notes
for h in hunks:
tgt_fn = _resolve_graph_filename(graph, h.target_path)
if not tgt_fn or h.target_line_start <= 0:
continue
decls = _decl_added_lines_for_file(diff_text, tgt_fn)
if decls <= 0:
continue
methods_here = [c for c in changed if c.kind == "method" and c.file == tgt_fn]
if decls > len(methods_here):
extra = decls - len(methods_here)
notes.append(
f"{extra} new method(s) not yet indexed; risk underestimated",
)
return notes
def map_hunks_to_symbols(
    graph: Any,
    hunks: list[DiffHunk],
    *,
    path_ambiguity_notes: list[str] | None = None,
) -> list[ChangedSymbol]:
    """Map diff hunks to overlapping `Symbol` rows (graph-resident only).

    For each hunk:
    - A pure deletion (no new-file lines, some old-file lines) marks every
      non-file symbol overlapping the old-file range as 'removed'.
    - A hunk with new-file lines marks every non-file symbol overlapping the
      new-file range as 'modified'.

    Symbols hit by several hunks are merged by id. The change type takes
    precedence modified > removed > other, and the overlapping line numbers
    are unioned. Each distinct diff path is resolved once per call; ambiguity
    notes from resolution are appended to `path_ambiguity_notes` when given.

    NOTE(review): hunk ranges from `parse_unified_diff` include context
    lines, so a symbol touched only by context can be reported.
    """
    by_id: dict[str, ChangedSymbol] = {}
    resolved: dict[str, str | None] = {}

    def resolve(path: str) -> str | None:
        # Cache per path. A file with many hunks was otherwise re-resolved (DB
        # round trips, duplicate ambiguity notes) for every hunk.
        if path not in resolved:
            resolved[path] = _resolve_graph_filename(
                graph, path, ambiguity_notes=path_ambiguity_notes,
            )
        return resolved[path]

    def merge(sym: ChangedSymbol) -> None:
        existing = by_id.get(sym.symbol_id)
        if existing is None:
            by_id[sym.symbol_id] = sym
            return
        kinds = (existing.change_type, sym.change_type)
        if "modified" in kinds:
            ct = "modified"
        elif "removed" in kinds:
            ct = "removed"
        else:
            ct = sym.change_type
        by_id[sym.symbol_id] = ChangedSymbol(
            symbol_id=existing.symbol_id,
            fqn=existing.fqn,
            kind=existing.kind,
            change_type=ct,
            file=existing.file,
            hunk_lines=sorted(set(existing.hunk_lines + sym.hunk_lines)),
        )

    def overlap(lo: int, hi: int, sym: SymbolHit) -> list[int]:
        return list(range(max(lo, sym.start_line), min(hi, sym.end_line) + 1))

    for h in hunks:
        tgt_fn = resolve(h.target_path)
        # `/dev/null` is the old side of an added file; it has no old-file
        # lines, so fall back to the target like an empty source path does.
        has_source = bool(h.source_path) and h.source_path != "/dev/null"
        src_fn = resolve(h.source_path) if has_source else tgt_fn
        if not tgt_fn and not src_fn:
            continue

        # Pure deletions: symbols overlapping the removed old-file lines. Query
        # only when the result is used. Mixed hunks are handled by the
        # new-file pass below.
        minus_only = h.target_length == 0 and h.source_length > 0
        if minus_only and src_fn and 0 < h.source_line_start <= h.source_line_end:
            rows = find_symbols_in_file_range(
                graph,
                filename=src_fn,
                start_line=h.source_line_start,
                end_line=h.source_line_end,
            )
            for sym in rows:
                if sym.kind == "file":
                    continue
                merge(_symbol_to_changed(
                    sym,
                    change_type="removed",
                    lines=overlap(h.source_line_start, h.source_line_end, sym),
                ))

        # Modified / added lines on the new file.
        if tgt_fn and 0 < h.target_line_start <= h.target_line_end:
            rows = find_symbols_in_file_range(
                graph,
                filename=tgt_fn,
                start_line=h.target_line_start,
                end_line=h.target_line_end,
            )
            for sym in rows:
                if sym.kind == "file":
                    continue
                merge(_symbol_to_changed(
                    sym,
                    change_type="modified",
                    lines=overlap(h.target_line_start, h.target_line_end, sym),
                ))
    return list(by_id.values())
def _impact_needle_for_changed(_graph: Any, fqn: str, mapped_kind: str) -> str:
"""Pick the `impact_analysis` needle: type FQN for members, else the symbol FQN."""
if mapped_kind in ("method", "field", "constructor"):
if "#" in fqn:
return fqn.split("#", 1)[0]
return fqn
def _is_public_interface_method(graph: Any, sym: SymbolHit) -> bool:
if sym.kind != "method":
return False
if "private" in (sym.modifiers or []):
return False
type_fqn = sym.fqn.split("#", 1)[0] if "#" in sym.fqn else sym.fqn
rows = graph._rows(
"MATCH (t:Symbol) WHERE t.fqn = $f AND t.kind = 'interface' RETURN t.id LIMIT 1",
{"f": type_fqn},
)
return bool(rows)
def _route_ids_for_symbol(graph: Any, symbol_id: str) -> list[str]:
# Note: LadybugDB rejects `ORDER BY r.id` together with `RETURN DISTINCT r.id` (binder loses `r`).
q = (
"MATCH (s:Symbol)-[e:EXPOSES]->(r:Route) WHERE s.id = $sid "
"RETURN r.id AS id ORDER BY id"
)
seen: set[str] = set()
out: list[str] = []
for row in graph._rows(q, {"sid": symbol_id}):
rid = str(row.get("id") or "")
if rid and rid not in seen:
seen.add(rid)
out.append(rid)
return out
def _cross_service_route_caller_count(graph: Any, rid: str) -> int:
    """Count HTTP clients plus async producers calling route `rid` over an edge with `match = 'cross_service'` (each query capped at 500 rows)."""
    clients = graph._rows(
        "MATCH (s:Symbol)-[:DECLARES_CLIENT]->(c:Client)-[e:HTTP_CALLS]->(r:Route {id: $rid}) "
        "WHERE e.match = 'cross_service' "
        "RETURN c.id AS id LIMIT 500",
        {"rid": rid},
    )
    producers = graph._rows(
        "MATCH (s:Symbol)-[:DECLARES_PRODUCER]->(p:Producer)-[e:ASYNC_CALLS]->(r:Route {id: $rid}) "
        "WHERE e.match = 'cross_service' "
        "RETURN p.id AS id LIMIT 500",
        {"rid": rid},
    )
    return len(clients) + len(producers)


def _risk_score_and_band(
    blast_total: int,
    cross_total: int,
    iface_hit: float,
    route_count: int,
    route_caller_total: int,
) -> tuple[float, str]:
    """Combine the aggregated signals into the v1 ``(score, band)`` pair.

    score = 0.4*min(blast, 100)/100 + 0.3*min(cross, 20)/20 + 0.2*iface
            + 0.1*min(routes, 5)/5 + min(route_callers, 5)/5, clamped to [0, 1].
    Bands: score < 0.3 is 'low', < 0.7 is 'medium', otherwise 'high'.
    """
    def _normalize(x: float, ceiling: float) -> float:
        if ceiling <= 0:
            return 0.0
        return min(float(x), ceiling) / ceiling

    # v1 risk weights / ceilings (PR-B §1.2): intentionally simple baselines;
    # these constants are expected to be tuned after real-world use — do not treat as stable.
    w_blast, cap_blast = 0.4, 100.0
    w_cross, cap_cross = 0.3, 20.0
    w_iface = 0.2
    w_routes, cap_routes = 0.1, 5.0
    raw = (
        w_blast * _normalize(float(blast_total), cap_blast)
        + w_cross * _normalize(float(cross_total), cap_cross)
        + w_iface * iface_hit
        + w_routes * _normalize(float(route_count), cap_routes)
    )
    cross_service_bonus = min(5.0, float(route_caller_total))
    score = max(0.0, min(1.0, raw + (cross_service_bonus / 5.0)))
    if score < 0.3:
        band = "low"
    elif score < 0.7:
        band = "medium"
    else:
        band = "high"
    return score, band


def compute_risk(graph: Any, changed: list[ChangedSymbol]) -> PrRiskReport:
    """Aggregate blast radius, routes, cross-service callers, and v1 risk score.

    For each changed symbol that still resolves to a `Symbol` row by id:
    - blast radius: the number of `graph.impact_analysis` results (depth 2,
      limit 400) for the declaring type (members) or the symbol itself;
    - cross-service callers: `graph.find_callers` edges (depth 2, limit 400)
      whose endpoints have different non-empty `microservice` values;
    - routes: the routes the symbol EXPOSES. The cross-service clients and
      producers of those routes are stored as `cross_service_callers_count`.
    Changed symbols with no matching `Symbol` row are left out of the report.

    Risk score stays in [0, 1]. Cross-service route callers add a bounded
    bump (up to +1.0) after normalization so they influence rank while
    preserving the public scalar contract. The returned `notes` is empty.
    """
    sym_cols = (
        "id", "kind", "name", "fqn", "package", "module", "microservice",
        "filename", "start_line", "end_line", "start_byte", "end_byte",
        "modifiers", "annotations", "capabilities", "role", "signature",
        "parent_id", "resolved",
    )
    sym_return = ", ".join(f"s.{c} AS {c}" for c in sym_cols)

    blast_by: dict[str, int] = {}
    blast_total = 0
    routes: list[str] = []
    seen_routes: set[str] = set()
    # A route exposed by several changed symbols has its callers queried once.
    route_callers: dict[str, int] = {}
    cross_total = 0
    iface_hit = 0.0
    enriched_changed: list[ChangedSymbol] = []

    for cs in changed:
        sym_row = graph._rows(
            "MATCH (s:Symbol) WHERE s.id = $id RETURN " + sym_return,
            {"id": cs.symbol_id},
        )
        if not sym_row:
            continue
        row0 = sym_row[0]
        if iface_hit < 1.0 and _is_public_interface_method(graph, _row_to_symbol(row0)):
            iface_hit = 1.0

        fqn = str(row0.get("fqn") or cs.fqn)
        needle = _impact_needle_for_changed(graph, fqn, cs.kind)
        n = len(graph.impact_analysis(needle, depth=2, limit=400))
        blast_by[cs.symbol_id] = n
        blast_total += n

        # Same graph FQN as the impact needle, not the possibly stale cs.fqn.
        for e in graph.find_callers(fqn, depth=2, limit=400):
            if (
                e.src.microservice
                and e.dst.microservice
                and e.src.microservice != e.dst.microservice
            ):
                cross_total += 1

        cs_cross_service = 0
        for rid in _route_ids_for_symbol(graph, cs.symbol_id):
            if rid not in seen_routes:
                seen_routes.add(rid)
                routes.append(rid)
            if rid not in route_callers:
                route_callers[rid] = _cross_service_route_caller_count(graph, rid)
            cs_cross_service += route_callers[rid]

        enriched_changed.append(
            ChangedSymbol(
                symbol_id=cs.symbol_id,
                fqn=cs.fqn,
                kind=cs.kind,
                change_type=cs.change_type,
                file=cs.file,
                hunk_lines=list(cs.hunk_lines),
                cross_service_callers_count=cs_cross_service,
            ),
        )

    score, band = _risk_score_and_band(
        blast_total,
        cross_total,
        iface_hit,
        len(routes),
        sum(c.cross_service_callers_count for c in enriched_changed),
    )
    return PrRiskReport(
        changed_symbols=list(enriched_changed),
        blast_radius_total=blast_total,
        blast_radius_by_symbol=blast_by,
        cross_service_callers=cross_total,
        routes_touched=routes,
        risk_score=score,
        risk_band=band,
        notes=[],
    )
def pr_report_to_dict(rep: PrRiskReport) -> dict[str, Any]:
    """Convert a report into plain dicts and lists.

    Each changed symbol goes through `dataclasses.asdict`. The per-symbol
    mapping and the two string sequences are copied into a new dict/lists so
    the result does not alias `rep`. Keys follow `PrRiskReport` field order.
    """
    symbols = [asdict(item) for item in rep.changed_symbols]
    return dict(
        changed_symbols=symbols,
        blast_radius_total=rep.blast_radius_total,
        blast_radius_by_symbol=dict(rep.blast_radius_by_symbol),
        cross_service_callers=rep.cross_service_callers,
        routes_touched=list(rep.routes_touched),
        risk_score=rep.risk_score,
        risk_band=rep.risk_band,
        notes=list(rep.notes),
    )
def analyze_pr_pipeline(graph: Any, diff_unified: str) -> PrRiskReport:
    """Run the full PR-B pipeline: parse → notes → map → risk.

    The report's notes are, in order: binary/rename file notes, path-ambiguity
    notes from symbol mapping, unindexed-addition heuristics, then any notes
    from `compute_risk`. Exact duplicates are dropped; the first occurrence wins.
    """
    file_notes = collect_diff_file_notes(diff_unified)
    hunks = parse_unified_diff(diff_unified)
    ambiguity: list[str] = []
    changed = map_hunks_to_symbols(graph, hunks, path_ambiguity_notes=ambiguity)
    heuristic = _notes_for_unindexed_additions(graph, diff_unified, changed, hunks)
    report = compute_risk(graph, changed)
    combined = file_notes + ambiguity + heuristic + list(report.notes)
    return PrRiskReport(
        changed_symbols=report.changed_symbols,
        blast_radius_total=report.blast_radius_total,
        blast_radius_by_symbol=report.blast_radius_by_symbol,
        cross_service_callers=report.cross_service_callers,
        routes_touched=report.routes_touched,
        risk_score=report.risk_score,
        risk_band=report.risk_band,
        notes=list(dict.fromkeys(combined)),
    )