Skip to content

Commit

Permalink
eager node and edge optimizers
Browse files Browse the repository at this point in the history
  • Loading branch information
bardhprenkaj committed Jul 19, 2024
1 parent f6010e6 commit 4a75a71
Show file tree
Hide file tree
Showing 10 changed files with 385 additions and 99 deletions.
44 changes: 44 additions & 0 deletions config/EAGER/tcr28.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"experiment": {
"scope": "eager_tcr28",
"parameters": {
"lock_release_tout":120,
"propagate":[
{"in_sections" : ["explainers"], "params": {"fold_id": 0}},
{"in_sections" : ["do-pairs/oracle"], "params": {"fold_id": -1}},
{"in_sections": ["do-pairs/dataset"],"params": { "compose_man" : "config/snippets/datasets/centr_and_weights.json" }}
],
"expand" : { "folds" : [ "explainers"], "triplets" : true }
}
},

"do-pairs":[ {"compose_tc28": "config/JMLR/snippets/do-pairs/TCR-500-28-7_custom.json"} ],
"explainers": [
{"class": "src.explainer.generative.eager.EAGER",

"parameters": {

"models": [
{
"class": "src.explainer.generative.gans.graph.learnable_edges.model.EdgeLearnableGAN",
"parameters": {
"epochs": 5000,
"alpha": 0.7,
"generator": {
"class": "src.explainer.generative.gans.graph.learnable_edges.generators.TranslatingGenerator",
"parameters": {
"in_embed_dim": 8,
"out_embed_dim": 4,
"num_translator_layers": 1
}
}
}
}
]
}
}
],
"compose_mes" : "config/snippets/default_metrics_w_dumper.json",
"compose_strs" : "config/snippets/default_store_paths.json"
}

100 changes: 100 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
name: gretel
channels:
- pytorch
- defaults
dependencies:
- blas=1.0=mkl
- brotli-python=1.0.9=py312hd77b12b_8
- bzip2=1.0.8=h2bbff1b_6
- ca-certificates=2024.3.11=haa95532_0
- certifi=2024.2.2=py312haa95532_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- cpuonly=2.0=0
- expat=2.6.2=hd77b12b_0
- filelock=3.13.1=py312haa95532_0
- freetype=2.12.1=ha860e81_0
- idna=3.7=py312haa95532_0
- intel-openmp=2023.1.0=h59b6b97_46320
- jinja2=3.1.3=py312haa95532_0
- jpeg=9e=h2bbff1b_1
- lcms2=2.12=h83e58a3_0
- lerc=3.0=hd77b12b_0
- libdeflate=1.17=h2bbff1b_1
- libffi=3.4.4=hd77b12b_1
- libjpeg-turbo=2.0.0=h196d8e1_0
- libpng=1.6.39=h8cc25b3_0
- libtiff=4.5.1=hd77b12b_0
- libuv=1.44.2=h2bbff1b_0
- libwebp-base=1.3.2=h2bbff1b_0
- lz4-c=1.9.4=h2bbff1b_1
- markupsafe=2.1.3=py312h2bbff1b_0
- mkl=2023.1.0=h6b88ed4_46358
- mkl-service=2.4.0=py312h2bbff1b_1
- mkl_fft=1.3.8=py312h2bbff1b_0
- mkl_random=1.2.4=py312h59b6b97_0
- mpmath=1.3.0=py312haa95532_0
- networkx=3.1=py312haa95532_0
- numpy=1.26.4=py312hfd52020_0
- numpy-base=1.26.4=py312h4dde369_0
- openjpeg=2.4.0=h4fc8c34_0
- openssl=3.0.13=h2bbff1b_1
- pillow=10.3.0=py312h2bbff1b_0
- pip=24.0=py312haa95532_0
- pysocks=1.7.1=py312haa95532_0
- python=3.12.3=h1d929f7_1
- pytorch=2.3.0=py3.12_cpu_0
- pytorch-mutex=1.0=cpu
- pyyaml=6.0.1=py312h2bbff1b_0
- requests=2.31.0=py312haa95532_1
- setuptools=69.5.1=py312haa95532_0
- sqlite=3.45.3=h2bbff1b_0
- sympy=1.12=py312haa95532_0
- tbb=2021.8.0=h59b6b97_0
- tk=8.6.14=h0416ee5_0
- torchaudio=2.3.0=py312_cpu
- torchvision=0.18.0=py312_cpu
- typing_extensions=4.11.0=py312haa95532_0
- urllib3=2.1.0=py312haa95532_1
- vc=14.2=h21ff451_1
- vs2015_runtime=14.27.29016=h5e58377_2
- wheel=0.43.0=py312haa95532_0
- win_inet_pton=1.1.0=py312haa95532_0
- xz=5.4.6=h8cc25b3_1
- yaml=0.2.5=he774522_0
- zlib=1.2.13=h8cc25b3_1
- zstd=1.5.5=hd43e919_2
- pip:
- aiohttp==3.9.5
- aiosignal==1.3.1
- atpublic==4.1.0
- attrs==23.2.0
- colorama==0.4.6
- contourpy==1.2.1
- cycler==0.12.1
- et-xmlfile==1.1.0
- flufl-lock==8.1.0
- fonttools==4.51.0
- frozenlist==1.4.1
- fsspec==2024.3.1
- joblib==1.4.2
- jsonc-parser==1.1.5
- jsonpickle==3.0.4
- kiwisolver==1.4.5
- matplotlib==3.8.4
- multidict==6.0.5
- openpyxl==3.1.2
- packaging==24.0
- pandas==2.2.2
- picologging==0.9.3
- psutil==5.9.8
- pyparsing==3.1.2
- python-dateutil==2.9.0.post0
- pytz==2024.1
- scikit-learn==1.4.2
- scipy==1.13.0
- six==1.16.0
- threadpoolctl==3.5.0
- torch-geometric==2.5.3
- tqdm==4.66.4
- tzdata==2024.1
- yarl==1.9.4
40 changes: 40 additions & 0 deletions src/explainer/generative/eager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch

from src.core.factory_base import get_instance_kvargs
from src.explainer.per_cls_explainer import PerClassExplainer

from src.utils.cfg_utils import init_dflts_to_of
from src.utils.samplers.abstract_sampler import Sampler

class EAGER(PerClassExplainer):
    """Per-class generative explainer backed by edge-learnable GANs.

    Delegates per-class generation to the parent class, then turns the
    generated node features and edge logits into a concrete counterfactual
    graph via a configurable sampler (Bernoulli by default).
    """

    def init(self):
        super().init()
        # Build the sampler declared in the configuration and attach the
        # dataset so it can draw candidate graphs from it.
        sampler_cfg = self.local_config['parameters']['sampler']
        self.sampler: Sampler = get_instance_kvargs(sampler_cfg['class'],
                                                    sampler_cfg['parameters'])
        self.sampler.dataset = self.dataset

    def explain(self, instance):
        """Return a sampled counterfactual for ``instance``.

        Falls back to the input instance itself when the sampler yields
        nothing (falsy result).
        """
        # Inference only — no gradients needed for the generative pass.
        with torch.no_grad():
            per_class_out = super().explain(instance)

        node_features, edge_probabilities = {}, {}
        for cls_key, tensors in per_class_out.items():
            # First tensor: reconstructed node features;
            # last tensor: edge logits, squashed to probabilities.
            node_features[cls_key] = tensors[0]
            edge_probabilities[cls_key] = torch.sigmoid(tensors[-1])

        candidate = self.sampler.sample(instance, self.oracle,
                                        embedded_features=node_features,
                                        edge_probabilities=edge_probabilities)

        return candidate or instance

    def check_configuration(self):
        # Default underlying model is the edge-learnable GAN.
        self.set_proto_kls('src.explainer.generative.gans.graph.learnable_edges.model.EdgeLearnableGAN')
        super().check_configuration()
        # Default sampler: Bernoulli with 500 sampling iterations.
        init_dflts_to_of(self.local_config,
                         'sampler',
                         'src.utils.samplers.bernoulli.Bernoulli',
                         sampling_iterations=500)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def __init__(self, num_nodes, dim=2, dropout=.4):
self.dropout = dropout
self.training = False
self.fc = nn.Linear(num_nodes * dim, 1).double()

self.training = True

self.device = (
"cuda"
Expand Down Expand Up @@ -47,10 +49,10 @@ def forward(self, x):
if self.training:
x = self.add_gaussian_noise(x)

x = torch.flatten(x)
x = x.flatten()
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc(x)
x = torch.sigmoid(x).squeeze()
x = torch.sigmoid(x)

return x

Expand Down
25 changes: 14 additions & 11 deletions src/explainer/generative/gans/graph/learnable_edges/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ class TranslatingGenerator(nn.Module):
def __init__(self, k: int, node_features: int,
in_embed_dim: int=2,
out_embed_dim: int=2, num_translator_layers: int=2) -> None:
super(TranslatingGenerator, self).__init__()
self.embedder = GraphEmbedder(k, node_features, in_embed_dim)

self.embedder = GraphEmbedder(k, node_features, out_embed_dim)

translator_layers = []
emb_dim = out_embed_dim
self.translator = nn.Sequential()
emb_dim = in_embed_dim
for _ in range(num_translator_layers):
translator_layers.append(nn.Linear(emb_dim, in_embed_dim))
translator_layers.append(nn.ReLU())
emb_dim = in_embed_dim

self.translator = nn.Sequential(**translator_layers)
self.translator.append(nn.Linear(emb_dim, out_embed_dim))
self.translator.append(nn.ReLU())
emb_dim = out_embed_dim

self.training = True

self.device = (
"cuda"
if torch.cuda.is_available()
Expand All @@ -47,11 +47,14 @@ def init_weights(self):
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def forward(self, node_features, edge_list, edge_weights):
cf_node_embeddings = self.embedder(node_features, edge_list, edge_weights)
def forward(self, node_features, edge_list, edge_weights, batch=None):
    """Embed the input graph, then translate the counterfactual node
    embeddings into the factual embedding space.

    ``batch`` is forwarded to the embedder for batched graph pooling
    (may be None for a single graph).
    """
    embedded = self.embedder(node_features, edge_list, edge_weights, batch)
    return self.translator(embedded)

def set_training(self, training):
    """Set the generator's training-mode flag.

    NOTE(review): this assigns nn.Module's ``training`` attribute directly
    instead of calling ``train()``/``eval()``, so submodules (embedder,
    translator) are not switched — confirm this is intended.
    """
    self.training = training


@default_cfg
def grtl_default(kls, k: int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,22 @@ def init_weights(self):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
torch.nn.init.zeros_(m.bias)


def set_training(self, training):
    """Set the embedder's training-mode flag (gates the Gaussian-noise
    augmentation applied in ``forward``).

    NOTE(review): assigns nn.Module's ``training`` attribute directly rather
    than calling ``train()``/``eval()`` — submodules are left untouched.
    """
    self.training = training

def forward(self, x, edge_list, edge_attr):
def forward(self, x, edge_list, edge_attr, batch=None):
x = x.double()
edge_attr = edge_attr.double()
x = self.conv(x, edge_list, edge_attr)

if self.training:
x = self.add_gaussian_noise(x)
x = F.relu(x)
x, _, _, _, _, _ = self.pool(x, edge_list, edge_attr)
x = F.relu(x)

x, _, _, _, _, _ = self.pool(x, edge_list, edge_attr, batch)
return x

def add_gaussian_noise(self, x, sttdev=0.2):
    """Return ``x`` plus zero-mean Gaussian noise with std ``sttdev``.

    The noise tensor is created on ``self.device`` with the same shape
    as ``x``; the input tensor itself is not modified.
    """
    perturbation = torch.randn(x.size(), device=self.device) * sttdev
    return x + perturbation

@default_cfg
def grtl_default(kls, num_nodes, node_feature_dim, dim=2):
return {"class": kls,
Expand All @@ -74,36 +66,54 @@ def grtl_default(kls, num_nodes, node_feature_dim, dim=2):
}
}

class PreDiscriminatorEmbedder(nn.Module):

def __init__(self, num_nodes, node_feature_dim, dim=2) -> None:
# embeds the graph into node vectors
self.embedder = GraphEmbedder(num_nodes, node_feature_dim, dim)
# decodes the embedded node vectors into the original node feature space
self.node_decoder = nn.Linear(dim, node_feature_dim)
class EdgeExistanceModule(nn.Module):

def __init__(self, dim=2) -> None:
    """Build the edge-existence decoder.

    Args:
        dim: per-node embedding size. Edge embeddings are the
            concatenation of two node vectors, hence the ``2 * dim``
            input width of the linear decoder.
    """
    super(EdgeExistanceModule, self).__init__()
    # decodes the edge embeddings (concatenation of two node vectors)
    # into a single existence logit per edge
    self.edge_decoder = nn.Linear(2 * dim, 1)

def forward(self, node_features, edge_list, edge_attrs) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor]:
emb_nodes: torch.Tensor = self.embedder(node_features, edge_list, edge_attrs)
def forward(self, emb_nodes, edge_list) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
def tensor_in_list(tensor, tensor_list):
return any(torch.equal(tensor, t) for t in tensor_list)

# repeat the node embeddings n times where n is the number of nodes
interleaved = torch.repeat_interleave(emb_nodes, repeats=emb_nodes.shape[0], dim=0)
repeated = emb_nodes.repeat(emb_nodes.shape[0])
repeated = emb_nodes.repeat(emb_nodes.shape[0], 1)
# where all rows are zero, then there's a self-loop, which we need to delete
loops = interleaved - repeated
non_empty_mask = loops.abs().sum(dim=0).bool()
# remove the self-loops
interleaved, repeated = interleaved[:,non_empty_mask], repeated[:,non_empty_mask]
non_empty_mask = loops.abs().sum(dim=1).bool()
# Initialize the real edges tensor
real_edges = []
for node1, node2 in list(zip(edge_list[0], edge_list[1])):
real_edges.append(torch.concat((emb_nodes[node1], emb_nodes[node2])))
# create edge embeddings
edge_embeddings = torch.concat([interleaved, repeated], dim=1)
# check if the edge exists
edge_exists = torch.empty(size=(edge_embeddings.shape[0], 1))
edge_logits = self.edge_decoder(edge_embeddings)
#new_edge_logits = edge_logits.clone()
#new_edge_logits[non_empty_mask] = 0

true_edges = torch.empty(size=(edge_embeddings.shape[0], 1))

for i, edge_embedding in enumerate(edge_embeddings):
edge_exists[i] = torch.sigmoid(self.edge_decoder(edge_embedding))
# reconstruct the node features
recons_nodes = self.node_decoder(emb_nodes)
return emb_nodes, recons_nodes, edge_embeddings, edge_exists

true_edges[i] = 1 if tensor_in_list(edge_embedding, real_edges) else 0

return edge_embeddings, edge_logits.squeeze(), true_edges.squeeze()


class NodeDecoderModule(nn.Module):
    """Decodes embedded node vectors back into the original feature space.

    Holds a GraphEmbedder (graph -> per-node embeddings of size ``dim``)
    and a linear projection mapping embeddings back to
    ``node_feature_dim``-dimensional node features.
    """

    def __init__(self, num_nodes, node_feature_dim, dim=2) -> None:
        super(NodeDecoderModule, self).__init__()
        # embeds the graph into node vectors
        self.embedder = GraphEmbedder(num_nodes, node_feature_dim, dim)
        # decodes the embedded node vectors into the original node feature space
        self.node_decoder = nn.Linear(dim, node_feature_dim)

    def forward(self, emb_nodes) -> torch.Tensor:
        """Project node embeddings back to node-feature space."""
        decoded = self.node_decoder(emb_nodes)
        return decoded

Loading

0 comments on commit 4a75a71

Please sign in to comment.