Skip to content

Commit

Permalink
Merge branch 'aiim-research:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
xsk07 authored Aug 28, 2024
2 parents dd2860a + 8c05cbb commit 27f4d91
Show file tree
Hide file tree
Showing 10 changed files with 753 additions and 124 deletions.
822 changes: 710 additions & 112 deletions lab/1-evaluation_pipeline.ipynb

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/data_analysis/future/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def draw_graph(cls, data_instance, position=None, img_store_address=None):
layout = nx.spring_layout
position = layout(G)

edge_colors = ['cyan' for u, v in G.edges()]
node_colors = ['cyan' for node in G.nodes()]
edge_colors = ['gray' for u, v in G.edges()]
node_colors = ['gray' for node in G.nodes()]

nx.draw_networkx(G=G, pos=position, node_color=node_colors, edge_color=edge_colors, with_labels=True)
nx.draw_networkx(G=G, pos=position, node_color=node_colors, edge_color=edge_colors, with_labels=False, node_size=100)

if img_store_address:
plt.savefig(img_store_address, format='svg')
Expand Down Expand Up @@ -305,9 +305,9 @@ def draw_counterfactual_actions(cls,

# Add shared nodes and edges in grey
for node in nodes_shared:
G.add_node(node, color='cyan')
G.add_node(node, color='gray')
for edge in edges_shared:
G.add_edge(*edge, color='cyan')
G.add_edge(*edge, color='gray')

# Add deleted nodes and edges in red
for node in nodes_deleted:
Expand All @@ -324,7 +324,7 @@ def draw_counterfactual_actions(cls,
edge_colors = [G[u][v]['color'] for u, v in G.edges()]
node_colors = [G.nodes[node]['color'] for node in G.nodes()]

nx.draw_networkx(G=G, pos=position, node_color=node_colors, edge_color=edge_colors, with_labels=True)
nx.draw_networkx(G=G, pos=position, node_color=node_colors, edge_color=edge_colors, with_labels=False, node_size=100)

if img_store_address:
plt.savefig(img_store_address, format='svg')
Expand Down
4 changes: 4 additions & 0 deletions src/dataset/instances/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def num_edges(self):
@property
def num_nodes(self):
return len(self.data)

@property
def is_directed(self):
return self.directed

def nodes(self):
return [ i for i in range(self.data.shape[0])]
Expand Down
8 changes: 7 additions & 1 deletion src/evaluation/future/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,10 @@ def pickle_explanations(self, store_path):

# Pickle the list into the specified file
with open(pickle_file_path, 'wb') as pickle_file:
pickle.dump(self._explanations, pickle_file)
for exp in self._explanations:
exp.input_instance._dataset = None
for inst in exp.counterfactual_instances:
inst._dataset = None

inst_cf_pairs = [(exp.input_instance, exp.counterfactual_instances[0]) for exp in self._explanations]
pickle.dump(self.inst_cf_pairs, pickle_file)
2 changes: 1 addition & 1 deletion src/explainer/future/ensemble/aggregators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get_all_edge_differences(self, instance: DataInstance, explanations: List[Da
edges = [(row, col) for row, col in zip(*np.where(edge_change_freq_matrix))]

# If we are working with directed graphs
if instance.directed:
if instance.is_directed:
filtered_edges = edges
else: # if we are working with undirected graphs
filtered_edges = []
Expand Down
9 changes: 6 additions & 3 deletions src/explainer/future/ensemble/aggregators/bidirectional.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def oblivious_forward_search(self, instance, e_add, e_rem, k=5, maximum_oracle_c
# remove
i,j = e_rem.pop(0)

if instance.directed:
if instance.is_directed:
cf_candidate_matrix[i][j]=0
else:
cf_candidate_matrix[i][j]=0
Expand All @@ -128,7 +128,7 @@ def oblivious_forward_search(self, instance, e_add, e_rem, k=5, maximum_oracle_c
# add
i,j = e_add.pop(0)

if instance.directed:
if instance.is_directed:
cf_candidate_matrix[i][j]=1
else:
cf_candidate_matrix[i][j]=1
Expand Down Expand Up @@ -173,7 +173,10 @@ def oblivious_backward_search(self, instance, cf_instance, changed_edges, k=5, m
# Revert the changes on the selected edges
for i,j in edges_i:
gci[i][j] = abs(1 - gci[i][j])
gci[j][i] = abs(1 - gci[j][i])

# If the graph is undirected we need to undo the symmetrical edge too
if not instance.is_directed:
gci[j][i] = abs(1 - gci[j][i])

reduced_cf_inst = GraphInstance(id=instance.id, label=0, data=gci)
self.dataset.manipulate(reduced_cf_inst)
Expand Down
5 changes: 5 additions & 0 deletions src/explainer/future/ensemble/aggregators/frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def real_aggregate(self, explanations: List[LocalGraphCounterfactualExplanation]
if mod_freq_matrix[edge[0], edge[1]] >= freq_threshold:
intersection_matrix[edge[0], edge[1]] = abs(intersection_matrix[edge[0], edge[1]] - 1 )

# If the graphs are undirected
if not input_inst.is_directed:
# Assign to the symetrical edge the same value than to the original edge
intersection_matrix[edge[1], edge[0]] = intersection_matrix[edge[0], edge[1]] # The original edge was already modified

# Create the aggregated explanation
aggregated_instance = GraphInstance(id=input_inst.id, label=1-input_inst.label, data=intersection_matrix)
self.dataset.manipulate(aggregated_instance)
Expand Down
5 changes: 5 additions & 0 deletions src/explainer/future/ensemble/aggregators/intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def real_aggregate(self, explanations: List[LocalGraphCounterfactualExplanation]
if mod_freq_matrix[edge[0], edge[1]] == exp_count:
intersection_matrix[edge[0], edge[1]] = abs(intersection_matrix[edge[0], edge[1]] - 1 )

# If the graphs are undirected
if not input_inst.is_directed:
# Assign to the symetrical edge the same value than to the original edge
intersection_matrix[edge[1], edge[0]] = intersection_matrix[edge[0], edge[1]] # The original edge was already modified

# Create the aggregated explanation
aggregated_instance = GraphInstance(id=input_inst.id, label=1-input_inst.label, data=intersection_matrix)
self.dataset.manipulate(aggregated_instance)
Expand Down
5 changes: 4 additions & 1 deletion src/explainer/future/ensemble/aggregators/rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ def real_aggregate(self, explanations: List[LocalGraphCounterfactualExplanation]
adj_matrix = copy.deepcopy(input_inst.data)
# Randomly sample a number of edges equivalent to the smallest base explanation
sampled_edges = random.sample(change_edges, min_changes)
#TODO consider the case of undirected graphs

# Try to modified the chosen edges one by one until a counterfactual is found
for edge in sampled_edges:
adj_matrix[edge[0], edge[1]] = abs( adj_matrix[edge[0], edge[1]] - 1 )

if not input_inst.is_directed:
# Assign to the symetrical edge the same value than to the original edge
adj_matrix[edge[1], edge[0]] = adj_matrix[edge[0], edge[1]] # The original edge was already modified

# Creating an instance with the modified adjacency matrix
aggregated_instance = GraphInstance(id=input_inst.id,
label=0,
Expand Down
5 changes: 5 additions & 0 deletions src/explainer/future/ensemble/aggregators/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def real_aggregate(self, explanations: List[LocalGraphCounterfactualExplanation]
for edge in mod_edges:
union_matrix[edge[0], edge[1]] = abs(union_matrix[edge[0], edge[1]] - 1 )

# If the graphs are undirected
if not input_inst.is_directed:
# Assign to the symetrical edge the same value than to the original edge
union_matrix[edge[1], edge[0]] = union_matrix[edge[0], edge[1]] # The original edge was already modified

# Create the aggregated explanation
aggregated_instance = GraphInstance(id=input_inst.id, label=1-input_inst.label, data=union_matrix)
self.dataset.manipulate(aggregated_instance)
Expand Down

0 comments on commit 27f4d91

Please sign in to comment.