Skip to content

Commit 36b62e0

Browse files
authored
EntityResolution batch. Close infiniflow#6570 (infiniflow#6602)
### What problem does this PR solve? EntityResolution batch ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
1 parent d2043ff commit 36b62e0

2 files changed

Lines changed: 28 additions & 20 deletions

File tree

graphrag/entity_resolution.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def __init__(
6363
self._resolution_result_delimiter_key = "resolution_result_delimiter"
6464
self._input_text_key = "input_text"
6565

66-
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
66+
async def __call__(self, graph: nx.Graph,
67+
subgraph_nodes: set[str],
68+
prompt_variables: dict[str, Any] | None = None,
69+
callback: Callable | None = None) -> EntityResolutionResult:
6770
"""Call method definition."""
6871
if prompt_variables is None:
6972
prompt_variables = {}
@@ -88,16 +91,19 @@ async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | Non
8891

8992
candidate_resolution = {entity_type: [] for entity_type in entity_types}
9093
for k, v in node_clusters.items():
91-
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
94+
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)]
9295
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
9396
callback(msg=f"Identified {num_candidates} candidate pairs")
9497

9598
resolution_result = set()
99+
resolution_batch_size = 100
96100
async with trio.open_nursery() as nursery:
97101
for candidate_resolution_i in candidate_resolution.items():
98102
if not candidate_resolution_i[1]:
99103
continue
100-
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
104+
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
105+
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
106+
nursery.start_soon(lambda: self._resolve_candidate(candidate_batch, resolution_result))
101107
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
102108

103109
change = GraphChange()
@@ -118,7 +124,7 @@ async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | Non
118124
change=change,
119125
)
120126

121-
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
127+
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]):
122128
gen_conf = {"temperature": 0.5}
123129
pair_txt = [
124130
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']

graphrag/general/index.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,34 +69,35 @@ async def run_graphrag(
6969
embedding_model,
7070
callback,
7171
)
72-
new_graph = None
73-
if subgraph:
74-
new_graph = await merge_subgraph(
75-
tenant_id,
76-
kb_id,
77-
doc_id,
78-
subgraph,
79-
embedding_model,
80-
callback,
81-
)
72+
if not subgraph:
73+
return
74+
75+
subgraph_nodes = set(subgraph.nodes())
76+
new_graph = await merge_subgraph(
77+
tenant_id,
78+
kb_id,
79+
doc_id,
80+
subgraph,
81+
embedding_model,
82+
callback,
83+
)
84+
assert new_graph is not None
8285

8386
if not with_resolution or not with_community:
8487
return
8588

86-
if new_graph is None:
87-
new_graph = await get_graph(tenant_id, kb_id)
88-
89-
if with_resolution and new_graph is not None:
89+
if with_resolution:
9090
await resolve_entities(
9191
new_graph,
92+
subgraph_nodes,
9293
tenant_id,
9394
kb_id,
9495
doc_id,
9596
chat_model,
9697
embedding_model,
9798
callback,
9899
)
99-
if with_community and new_graph is not None:
100+
if with_community:
100101
await extract_community(
101102
new_graph,
102103
tenant_id,
@@ -223,6 +224,7 @@ async def merge_subgraph(
223224

224225
async def resolve_entities(
225226
graph,
227+
subgraph_nodes: set[str],
226228
tenant_id: str,
227229
kb_id: str,
228230
doc_id: str,
@@ -241,7 +243,7 @@ async def resolve_entities(
241243
er = EntityResolution(
242244
llm_bdl,
243245
)
244-
reso = await er(graph, callback=callback)
246+
reso = await er(graph, subgraph_nodes, callback=callback)
245247
graph = reso.graph
246248
change = reso.change
247249
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")

0 commit comments

Comments
 (0)