Problem when passing a SparseTensor to PyG GCNConv #382
Description
The problem occurs when I try to pass a SparseTensor to PyG's GCNConv. I'm working with Python 3.10, CUDA 12.1, torch 2.2.0, PyG 2.5.2 and torch_sparse 0.6.18, all installed with conda on an Ubuntu server, and the forward call fails with the error below. No matter how I change the way I create the SparseTensor object, the problem persists. I'm wondering whether it comes from a version compatibility issue or from something wrong in my environment setup (which is very simple: I only installed torch, PyG and torch_sparse). Has anyone run into a similar problem, or does anyone have an idea why it happens?
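Since the open question is whether this is a version mismatch, it may help to paste the exact installed versions next to the report. Below is a minimal sketch using only the standard library plus an optional `import torch`. The helper names `installed_versions` and `torch_cuda_build` are my own, and the distribution names `torch`, `torch_geometric` and `torch_sparse` are an assumption (a conda build may register them under different names, in which case they show up as not installed).

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Assumed distribution names; conda packages may register different ones.
PACKAGES: Tuple[str, ...] = ("torch", "torch_geometric", "torch_sparse")


def installed_versions(packages: Tuple[str, ...] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if not found."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def torch_cuda_build() -> Optional[str]:
    """CUDA version torch was built against; None if torch is missing or CPU-only."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:16} {version or 'not installed'}")
    print(f"{'torch CUDA build':16} {torch_cuda_build()}")
```

Running it as a script prints one line per package plus the CUDA version torch itself was compiled for, which can then be compared with the CUDA 12.1 toolkit reported above.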
You should be able to reproduce the issue by running the following code:
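```python
import torch
from torch_geometric.nn import GCNConv
from torch_sparse import SparseTensor


def test():
    # Three edges (2, 1), (3, 2), (4, 3) on a 5-node graph
    ei = torch.tensor([[2, 3, 4], [1, 2, 3]]).cuda(0)
    sp = SparseTensor.from_edge_index(ei, sparse_sizes=(5, 5))
    model = GCNConv(2, 2).cuda(0)
    x = torch.tensor([[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]).float().cuda(0)
    print(x, sp)
    model(x, sp)  # raises the RuntimeError shown below
    print('success')


test()
```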
Here is the error message:
```
Traceback (most recent call last):
  File "/.../debug.py", line 15, in
    test()
  File "/.../debug.py", line 13, in test
    model(x, sp)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 252, in forward
    edge_index = gcn_norm(  # yapf: disable
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 64, in gcn_norm
    adj_t = torch_sparse.fill_diag(adj_t, fill_value)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_sparse/diag.py", line 92, in fill_diag
    return set_diag(src, value.new_full(sizes, fill_value), k)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_sparse/diag.py", line 49, in set_diag
    new_row[mask] = row
RuntimeError: shape mismatch: value tensor of shape [3] cannot be broadcast to indexing result of shape [0]
```
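Reading the last frame: `row` holds the three off-diagonal entries of `sp`, so the message says the boolean `mask` selected zero slots of `new_row` where three were expected. As far as I can tell from `torch_sparse/diag.py`, `set_diag` builds merged index vectors of length nnz plus the number of diagonal entries, and `mask` (computed, I believe, by the compiled op `torch.ops.torch_sparse.non_diag_mask`) marks where the original entries go. The pure-Python sketch below is only my own reference for what that mask should look like; the helpers `expected_non_diag_mask` and `scatter_rows` are hypothetical, not torch_sparse API, and it assumes entries are kept in row-major (row, then column) order with the main diagonal (k = 0) inserted at its column position.

```python
from typing import List, Sequence


def expected_non_diag_mask(row: Sequence[int], col: Sequence[int],
                           num_rows: int, num_cols: int) -> List[bool]:
    """True where an original off-diagonal entry lands after inserting the diagonal."""
    num_diag = min(num_rows, num_cols)
    # Tag original entries True and new diagonal entries False, then merge
    # them in row-major order (assumed to match SparseTensor's layout).
    merged = [(r, c, True) for r, c in zip(row, col)]
    merged += [(i, i, False) for i in range(num_diag)]
    merged.sort(key=lambda entry: (entry[0], entry[1]))
    return [is_original for _, _, is_original in merged]


def scatter_rows(mask: Sequence[bool], row: Sequence[int]) -> List[int]:
    """Mimic `new_row[mask] = row; new_row[~mask] = diag` for k = 0."""
    num_selected = sum(mask)
    if num_selected != len(row):
        # Same situation as the RuntimeError above: 3 values, 0 masked slots.
        raise ValueError(f"shape mismatch: {len(row)} values for {num_selected} masked slots")
    diag = iter(range(len(mask) - len(row)))  # diagonal row indices 0, 1, 2, ...
    src = iter(row)
    return [next(src) if keep else next(diag) for keep in mask]


if __name__ == "__main__":
    mask = expected_non_diag_mask([2, 3, 4], [1, 2, 3], 5, 5)
    print(mask, sum(mask))
    print(scatter_rows(mask, [2, 3, 4]))
```

For the 5-by-5 example this reference mask has 8 entries with exactly 3 set to True, and the merged row vector becomes `[0, 1, 2, 2, 3, 3, 4, 4]`. If the mask produced on the GPU selects nothing instead, that would be consistent with the compiled extension misbehaving (for example a build mismatch) rather than with how the SparseTensor was constructed, though this is only a hypothesis.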