Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Molecule generation model (GeoDiff) #54

Closed
wants to merge 27 commits into from
Closed

Molecule generation model (GeoDiff) #54

wants to merge 27 commits into from

Conversation

natolambert
Copy link
Contributor

@natolambert natolambert commented Jun 30, 2022

Added a new model file to add functionality for this paper that does molecule generation via diffusion.

Pretrained models available are for two tasks, drugs and qm9:

model1 = DualEncoderEpsNetwork.from_pretrained("natolambert/geodiff-qm9-dualenc")
model1 = DualEncoderEpsNetwork.from_pretrained("natolambert/geodiff-drugs-dualenc")

Will work on additional examples for this model and update this PR.
Some todo items:

  • setup colab dependencies to run model: is here
  • add model tests
  • update colab
  • add documentation for model
  • rename model to MoleculeGNN
  • add dependency checks for torch_geometric<2, pytorch 1.8, and torch_scatter (a recommended installation method is in the colab)

Some comments:

  • the GNN implementations came from ConfGF
  • an issue to open after release will be to figure out how to port the existing models to the new version of pytorch geometric. The author is not sure if one can copy the parameter dict or if re-training will be needed. This will ease some of the specific requirements for this model.

@natolambert natolambert changed the title Molecule generation model (GeoDiff) [post release] Molecule generation model (GeoDiff) Jul 20, 2022
@natolambert
Copy link
Contributor Author

Doing more digging, the model architecture is based on GIN and SchNet (common graph neural networks).

There are local and global parameters of the molecule, and different components use different GNNs.

@@ -15,6 +15,7 @@

import inspect
import math
import pdb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be careful to always remove those before merging :-) Totally fine to keep them for testing though!

@patrickvonplaten
Copy link
Contributor

Hey @natolambert,

This PR looks super cool. I think we can merge it quite quickly! Leaving some feedback directly in the code. Regarding the notebook, the API: In general, the notebook looks very nice IMO. If we have to install certain dependencies so be it! Regarding the model API I would suggest to change it a bit. E.g.:

# generate geometry with model, then filter it
model_outputs = model.forward(batch, t)

# this model uses additional conditioning of the outputs depending on the current timestep
epsilon = model.get_residual(pos, sigmas[t], model_outputs)["sample"]

Do you think we can merge this into a single forward pass e.g. something like:

epsilon = model(batch, t, sigma)["sample"]

?

@@ -68,6 +68,13 @@ def __init__(
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool! Nice to see that "vanilla" DDPM can be used that easily

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According the author the model performs better quantitatively with a different scheduler, but visually I noticed no difference.

from torch.nn import Embedding, Linear, Module, ModuleList, Sequential

from rdkit.Chem.rdchem import BondType as BT
from torch_geometric.nn import MessagePassing, radius, radius_graph
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to make sure that this model is only imported if torch_geometric or torch_scatter is present. We should therefore adapt the same logic we did for the optional transformers import.

Could you do the following:

  1. Add a is_torch_scatter_available and a is_torch_geometric_available to https://github.com/huggingface/diffusers/blob/main/src/diffusers/utils/__init__.py (just like it has been done for other dependencies)
  2. Import DualEncoderEpsNetwork only if both dependencies are available here:
    from .vae import AutoencoderKL, VQModel
    (just like we do it for LDM in pipelines here:
    if is_transformers_available():
  3. Ony import the model if the depedency is available in the main init as well:
    if is_transformers_available():
    - otherwise import the dummy class
  4. Having written 3.) run make fix-styles to automatically create the dummy class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the painpoint here is that is needs specific versions too, because this code was made before breaking changes in torch_geometric. Can I do similar functions for that?



class DualEncoderEpsNetwork(ModelMixin, ConfigMixin):
def __init__(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure to not forget the @register_to_config decorator here:

@register_to_config

return score_pos


class DualEncoderEpsNetwork(ModelMixin, ConfigMixin):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is the first graph network NN in this library, we should be extra careful with the naming.
Is this a universally understandable name? Do you think other graph networks would also use this architecuter? Should we make the name more generic in this case? Can we link to a paper here that defined that model architecture?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Link to original code is above now too, I followed up with those authors asking if their's was original. https://github.com/DeepGraphLearning/ConfGF

@natolambert
Copy link
Contributor Author

I'll work through your comments soon @patrickvonplaten.
CC the original author @MinkaiXu in case he has any time to look.



class CFConv(MessagePassing):
def __init__(self, in_channels, out_channels, num_filters, nn, cutoff, smooth):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you import nn from torch you can't use it as parameter/variable name, it's sure to create bugs later on

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, if you are redefining here CFConv from the SchNet paper, I think you should init the linear/shifted softplus/linear layers (what you put in nn) here instead of passing them as arguments, as it makes it hard to follow/find the logic of the paper in your code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(If you want to allow something more general, the default init for this layer should still be the original one from the paper)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah I hadn't seen this nn usage until now, I really dislike that. It's very confusing (and here is clear how copy-pasted blocks of this code were).

I think this was the result of copying code from multiple files too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RE the second point, I don't feel as strongly about it. It was taken directly from a few codebases back. In this case, the classes CFConv and InteractionBlock could effectively become one?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, if you open the SchNet paper, there is Figure 2 (I think) which describes the model arch. I think it could be interesting if you went through the figure along with your code: having CFConv and InteractionBlock separate is not a problem, but it would make CFConv easier to match with the paper it it defined the nn layers in directly. For further readers of this code, it will be clearer, IMO

src/diffusers/models/molecule_gnn.py Outdated Show resolved Hide resolved
@@ -0,0 +1,640 @@
# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think overall it would help if you could add class documentation of your different components and type hinting, at least in the inits and forwards

super(MultiLayerPerceptron, self).__init__()

self.dims = [input_dim] + hidden_dims
if isinstance(activation, str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If someone passes an activation function instead of an activation function name this will fail silently

src/diffusers/models/molecule_gnn.py Outdated Show resolved Hide resolved

class SchNetEncoder(Module):
def __init__(
self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't paper default for num_interactions 3?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the SchNet paper? No clue. This was copied from author.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I guess we can close this comment then (it is 3 in the SchNet paper, I looked it up along your code yesterday, but it's not that important, and if the other model relies on 6 as default, might be easier to keep 6)



class GINEConv(MessagePassing):
def __init__(self, nn: Callable, eps: float = 0.0, train_eps: bool = False, activation="softplus", **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are importing classes from torchgeometric, why are you redefining GINEConv instead of using the default version? (Could be worth some doc to explain what is different - the activation function choice?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed-up with author @MinkaiXu directly, and he got this from the implementation he built on. I can look a little more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at the source and it's the same with the optional addition of an activation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth adding it in the class doc IMO

@@ -0,0 +1,640 @@
# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know what the standard of this library is wrt asserts vs raising exceptions, but you might need to check this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anton-l or @patil-suraj any comment on this?


hiddens = []
conv_input = node_attr # (num_node, hidden)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is node_attr for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean with respect to the model / application? Or just in this code? Both of them I have only intermediate understanding of.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I just meant that you do

node_attr = self.node_emb(z)
conv_attr = node_attr

why not directly

conv_input = self.node_emb(z)

?

src/diffusers/models/molecule_gnn.py Show resolved Hide resolved
@patrickvonplaten
Copy link
Contributor

We should be able to add it now that things are calmer 🥳

@georgosgeorgos
Copy link

@natolambert @clefourrier @patrickvonplaten any plan to integrate this PR?

@MinkaiXu
Copy link

MinkaiXu commented Sep 8, 2022

I took a quick look and everything is pretty good (Thank you all for your efforts!)
Ping me if there is any part I need to take a deeper look before merging the PR :)

Specifically thanks @natolambert a lot for discussing many details together!

@natolambert
Copy link
Contributor Author

@georgosgeorgos and @MinkaiXu, sorry for the delay on merging. We got a little distracted by Stable Diffusion. I'll plan on merging the updates on main to this + the notebook then we should be able to close the PR soon.

The colab should run in its current form! So if you look at that, happy to take comments!

@github-actions
Copy link

github-actions bot commented Oct 3, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 3, 2022
@natolambert natolambert removed the stale Issues that haven't received updates label Oct 3, 2022
@natolambert
Copy link
Contributor Author

Fighting the stale bot! We haven't forgotten about this, and actually moved it up the priority list today. Soon!

@natolambert natolambert changed the title [post release] Molecule generation model (GeoDiff) Molecule generation model (GeoDiff) Oct 3, 2022
@MinkaiXu
Copy link

MinkaiXu commented Oct 3, 2022

I'm glad to help, but not quite familiar with the whole process --- safety_checker.py is some mandatory step for diffusers? If so, is there a source code for this file?

@natolambert
Copy link
Contributor Author

@MinkaiXu, @patrickvonplaten moved fast and removed it here. An interesting little difference I am not aware of :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@natolambert
Copy link
Contributor Author

@patrickvonplaten @anton-l @patil-suraj: so I realized the tests I implemented are not that useful because in order to test them you need:

  • pytorch 1.8,
  • torch_geometric 1.7.2

These are kind of a pain to install with conda install -c rusty1s pytorch-geometric=1.7.2 after installing pytorch==1.8 from source.

How should we think about integrating these tests?

@patil-suraj
Copy link
Contributor

Why do we need to install pytorch from source ? Also is torch_geometric not available through pip ?

@natolambert
Copy link
Contributor Author

natolambert commented Oct 5, 2022

@patil-suraj TLDR is re-implemented a research paper's code and it was made on a version before a lot of breaking changes. Has made this pr-for-a-colab a bit unwieldy.

I'm actually not sure it any version of torch_geometric is available on pip, I don't think so.

@patrickvonplaten
Copy link
Contributor

Do we already have a working google colab for this model.
Also if a model is simply to difficult to implement, we might also only add the part that is easy to add to diffusers and for the rest just rely on an example

@natolambert
Copy link
Contributor Author

natolambert commented Oct 7, 2022

Oh yeah I can just put it all in the colab and not port it into diffusers. May be easier.

In summary, I would merge in sigmoid noise schedule and then just copy the model into the colab. Thoughts?

@patrickvonplaten
Copy link
Contributor

sounds good!

@natolambert natolambert mentioned this pull request Oct 7, 2022
@natolambert natolambert closed this Oct 7, 2022
@natolambert
Copy link
Contributor Author

We closed this PR in favor of an colab-only solution unless someone has time to update the source model to the new versions of torch_geometric so it doesn't require very different dependencies!

@georgosgeorgos
Copy link

Do we already have a working google colab for this model. Also if a model is simply to difficult to implement, we might also only add the part that is easy to add to diffusers and for the rest just rely on an example

@patrickvonplaten @natolambert where is the working colab with geodiff in the diffusers repo?

@patrickvonplaten
Copy link
Contributor

patrickvonplaten commented Oct 10, 2022

Let's leave it open until we have a colab version. We can also make use of community pipelines: https://huggingface.co/docs/diffusers/using-diffusers/custom_pipelines :-)

@natolambert
Copy link
Contributor Author

Created a PR in notebooks and the colab is here -- need to update the link once the PR in notebooks is merged.

@github-actions github-actions bot added the stale Issues that haven't received updates label Nov 4, 2022
@huggingface huggingface deleted a comment from github-actions bot Nov 4, 2022
@anton-l anton-l added wip and removed stale Issues that haven't received updates labels Nov 4, 2022
Dango233 pushed a commit to Dango233/diffusers that referenced this pull request Dec 9, 2022
@rish-16
Copy link

rish-16 commented Feb 15, 2023

Hey! Thanks for adding GeoDiff into the pipeline :D

TorsionDiff (https://arxiv.org/abs/2206.01729) might be another cool approach for the library. It requires fewer diffusion steps for conformational generation than GeoDiff!

code: https://github.com/gcorso/torsional-diffusion

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants