Skip to content

Commit 0671a3d

Browse files
committed
Return primary key instead of offset for networkx
1 parent 24a229f commit 0671a3d

4 files changed

Lines changed: 50 additions & 35 deletions

File tree

src_py/connection.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,34 @@ def set_max_threads_for_exec(self, num_threads):
1212

1313
def execute(self, query, parameters=[]):
1414
return QueryResult(self, self._connection.execute(query, parameters))
15+
16+
def _get_node_property_names(self, table_name):
17+
PRIMARY_KEY_SYMBOL = "(PRIMARY KEY)"
18+
LIST_SYMBOL = "[]"
19+
result_str = self._connection.get_node_property_names(
20+
table_name)
21+
results = {}
22+
for (i, line) in enumerate(result_str.splitlines()):
23+
# ignore first line
24+
if i == 0:
25+
continue
26+
line = line.strip()
27+
if line == "":
28+
continue
29+
line_splited = line.split(" ")
30+
if len(line_splited) < 2:
31+
continue
32+
33+
prop_name = line_splited[0]
34+
prop_type = " ".join(line_splited[1:])
35+
36+
is_primary_key = PRIMARY_KEY_SYMBOL in prop_type
37+
prop_type = prop_type.replace(PRIMARY_KEY_SYMBOL, "")
38+
dimension = prop_type.count(LIST_SYMBOL)
39+
prop_type = prop_type.replace(LIST_SYMBOL, "")
40+
results[prop_name] = {
41+
"type": prop_type,
42+
"dimension": dimension,
43+
"is_primary_key": is_primary_key
44+
}
45+
return results

src_py/query_result.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def get_as_networkx(self, directed=True):
7777
nodes = {}
7878
rels = {}
7979
table_to_label_dict = {}
80+
table_primary_key_dict = {}
81+
82+
def encode_node_id(node, table_primary_key_dict):
83+
return node['_label'] + "_" + str(node[table_primary_key_dict[node['_label']]])
8084

8185
# De-duplicate nodes and rels
8286
while self.has_next():
@@ -98,17 +102,24 @@ def get_as_networkx(self, directed=True):
98102
for node in nodes.values():
99103
_id = node["_id"]
100104
node_id = node['_label'] + "_" + str(_id["offset"])
105+
if node['_label'] not in table_primary_key_dict:
106+
props = self.connection._get_node_property_names(node['_label'])
107+
for prop_name in props:
108+
if props[prop_name]['is_primary_key']:
109+
table_primary_key_dict[node['_label']] = prop_name
110+
break
111+
node_id = encode_node_id(node, table_primary_key_dict)
101112
node[node['_label']] = True
102113
nx_graph.add_node(node_id, **node)
103114

104115
# Add rels
105116
for rel in rels.values():
106117
_src = rel["_src"]
107118
_dst = rel["_dst"]
108-
src_id = str(
109-
table_to_label_dict[_src["table"]]) + "_" + str(_src["offset"])
110-
dst_id = str(
111-
table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"])
119+
src_node = nodes[(_src["table"], _src["offset"])]
120+
dst_node = nodes[(_dst["table"], _dst["offset"])]
121+
src_id = encode_node_id(src_node, table_primary_key_dict)
122+
dst_id = encode_node_id(dst_node, table_primary_key_dict)
112123
nx_graph.add_edge(src_id, dst_id, **rel)
113124
return nx_graph
114125

src_py/torch_geometric_result_converter.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,8 @@ def __init__(self, query_result):
2020
def __get_node_property_names(self, table_name):
2121
if table_name in self.nodes_property_names_dict:
2222
return self.nodes_property_names_dict[table_name]
23-
24-
PRIMARY_KEY_SYMBOL = "(PRIMARY KEY)"
25-
LIST_SYMBOL = "[]"
26-
result_str = self.query_result.connection._connection.get_node_property_names(
23+
results = self.query_result.connection._get_node_property_names(
2724
table_name)
28-
results = {}
29-
for (i, line) in enumerate(result_str.splitlines()):
30-
# ignore first line
31-
if i == 0:
32-
continue
33-
line = line.strip()
34-
if line == "":
35-
continue
36-
line_splited = line.split(" ")
37-
if len(line_splited) < 2:
38-
continue
39-
40-
prop_name = line_splited[0]
41-
prop_type = " ".join(line_splited[1:])
42-
43-
is_primary_key = PRIMARY_KEY_SYMBOL in prop_type
44-
prop_type = prop_type.replace(PRIMARY_KEY_SYMBOL, "")
45-
dimension = prop_type.count(LIST_SYMBOL)
46-
prop_type = prop_type.replace(LIST_SYMBOL, "")
47-
results[prop_name] = {
48-
"type": prop_type,
49-
"dimension": dimension,
50-
"is_primary_key": is_primary_key
51-
}
5225
self.nodes_property_names_dict[table_name] = results
5326
return results
5427

test/test_networkx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_to_networkx_node(establish_connection):
5252
}
5353
for i in range(len(nodes)):
5454
node_id, node = nodes[i]
55-
assert node_id == "%s_%d" % (node['_label'], node['_id']['offset'])
55+
assert node_id == "%s_%d" % (node['_label'], node['ID'])
5656
for key in ground_truth:
5757
assert node[key] == ground_truth[key][i]
5858

@@ -110,7 +110,7 @@ def test_networkx_undirected(establish_connection):
110110
'person', 'person'],
111111
}
112112
for (node_id, node) in nodes:
113-
assert node_id == "%s_%d" % (node['_label'], node['_id']['offset'])
113+
assert node_id == "%s_%d" % (node['_label'], node['ID'])
114114

115115
for (_, node) in nodes:
116116
found = False
@@ -166,7 +166,7 @@ def test_networkx_directed(establish_connection):
166166
}
167167

168168
for (node_id, node) in nodes:
169-
assert node_id == "%s_%d" % (node['_label'], node['_id']['offset'])
169+
assert node_id == "%s_%d" % (node['_label'], node['ID'])
170170

171171
for (_, node) in nodes:
172172
if 'person' not in node:

0 commit comments

Comments
 (0)