-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add UNet 1d for RL model for planning + colab #105
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
Hi, is there any updates on this? |
Hi @ezhang7423, we have mostly paused merging this until Stable Diffusion calms down (which can be changed if there is interest!). That being said, this PR is very stable. What are you interested in? Have you looked at the colab? Happy to discuss anything here! |
Thank you so much for the helpful response! I'm curious about whether the 1D convolution class is actually necessary- theoretically, I think you could do everything with a 2D convolution with a non square kernel. Would it be easier to just create a wrapper class on top of the existing models that does this? I'm also wondering if you have the code for training one of these models from scratch. |
Ah, so this model is mostly a direct port from the original code. We would definitely welcome community contribution to help re-use the existing model components. Part of the reason we did not is because training is complex. It involves a lot more RL infrastructure we don't really intend to support in diffusers. I think the first step would be to re-train in another repo with diffusers as a dependency. I'm happy to discuss further if this is something you want to work on, but it's not really on my runway (because I know it will be very challenging, given my experience in RL). |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
Fighting the stale bot! We haven't forgotten about this, and actually moved it up the priority list today. Soon! |
Okay, I just updated the Colab too so it works. Now a couple final thoughts:
|
src/diffusers/models/unet_rl.py
Outdated
t = self.time_mlp(timestep) | ||
h = [] | ||
|
||
for resnet, resnet2, downsample in self.downs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's really try to mirror the design in src/diffusers/models/unet_2d.py
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll try and make it work with the get_block
function. The parameter re-naming script is getting quite ugly so we'll see how far I make it 🤗.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we try to morph the design more into what we have in unet_2d.py
?
Also to be the unet looks quite general, so instead of having it in a specific unet_rl.py
, I think we could put it in a more general unet_1d.py
class no? What do you think @natolambert ?
I agree. Was expecting some time on improving code quality, I just wanted to know where to start @patrickvonplaten. I will take a deeper look and ask questions when I'm stuck. Most of the code is still in the form of the original. |
@patrickvonplaten regarding your last comment above too, yes I think this can become The tricky thing is that the model here uses three dimensional tensors [batch x time_dimension x state_dimension], so that may be different than some 1d use-cases. Maybe we have some optional batching? Or, the RL model becomes a wrapper of a 1d unet? |
If it's a very special model, then indeed let's put it in its own unet class :-) Maybe easier to begin with anyways |
The documentation is not available anymore as the PR was closed or merged. |
@@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie | |||
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ | | |||
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | |||
| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | |||
|
|||
| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool!
import d4rl # noqa | ||
import gym | ||
import tqdm | ||
from diffusers.experimental import ValueGuidedRLPipeline |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perfect :-)
src/diffusers/models/embeddings.py
Outdated
self.act = None | ||
if act_fn == "silu": | ||
self.act = nn.SiLU() | ||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) | ||
if act_fn == "mish": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if act_fn == "mish": | |
elif act_fn == "mish": |
@@ -62,14 +62,21 @@ def get_timestep_embedding( | |||
|
|||
|
|||
class TimestepEmbedding(nn.Module): | |||
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): | |||
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We never use the keyword argument "channel=" anywhere no?
src/diffusers/models/resnet.py
Outdated
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): | ||
super().__init__() | ||
|
||
self.blocks = nn.ModuleList( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a big fan of using blocks here as it makes it very hard to adapt the class to future models. Sorry this is a bit annoying to do when converting the checkpoint, but could we maybe instead call the layers self.conv_in
and self.conv_out
? The idea here is that if in the future there are checkpoints which have an intermediate conv layer it'd be much easier to adapt this layer without breaking the previous checkpoitns
@@ -1,3 +1,17 @@ | |||
# Copyright 2022 The HuggingFace Team. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!
@@ -54,16 +75,20 @@ def __init__( | |||
out_channels: int = 2, | |||
extra_in_channels: int = 0, | |||
time_embedding_type: str = "fourier", | |||
freq_shift: int = 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This we cannot remove anymore sadly as we have dance diffusion now which relies on this config parameter https://huggingface.co/harmonai/maestro-150k/blob/main/unet/config.json#L35 -> could we leave the name as is here please :-) Otherwise we'd break dance diffusion 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you feel super strongly about it we could also adapt the config of dance diffusion online and leave as is
src/diffusers/models/unet_1d.py
Outdated
flip_sin_to_cos: bool = True, | ||
use_timestep_embedding: bool = False, | ||
downscale_freq_shift: float = 0.0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll have to use freq_shift
I'm afraid
src/diffusers/models/unet_1d.py
Outdated
act_fn: str = None, | ||
norm_num_groups: int = 8, | ||
layers_per_block: int = 1, | ||
always_downsample: bool = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not a huge fan of the naming here "always" - what does "always" mean. Maybe "downsample_each_block" or something instead?
|
||
# 5. post-process | ||
if self.out_block: | ||
sample = self.out_block(sample, timestep_embed) | ||
|
||
if not return_dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool adaptions here! We just need to make sure the dance diffusion tests still passes :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I had been testing along the way, will double check before merge!
for layer in self.final_block: | ||
hidden_states = layer(hidden_states) | ||
|
||
return hidden_states |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that works for me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice naming :-)
@@ -374,11 +598,71 @@ def __init__(self, in_channels, out_channels, mid_channels=None): | |||
|
|||
self.resnets = nn.ModuleList(resnets) | |||
|
|||
def forward(self, hidden_states, res_hidden_states_tuple): | |||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok for me!
@@ -204,6 +204,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): | |||
# for rl-diffuser https://arxiv.org/abs/2205.09991 | |||
elif variance_type == "fixed_small_log": | |||
variance = torch.log(torch.clamp(variance, min=1e-20)) | |||
variance = torch.exp(0.5 * variance) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uff that's a pretty big change - do we need to do this? Is this not affecting any of our DDPM checkpoints/tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patrickvonplaten I added this variance_type
condition specifically for the RL implementation. Double checked that nothing else in the code base uses it.
|
||
def test_forward_with_norm_groups(self): | ||
# Not implemented yet for this UNet | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice tests!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool! Everything looks more or less good to me, just 2 questions:
- 1.) Are we sure we're not breaking dance diffusion here? Let's make sure the tests pass and even if they do we have to adapt all the configs online if we want to change the naming
- 2.) Are the changes to the DDPM scheduler safe? Would be nice to also run the DDPM pipeline tests here to be sure
It'd be super nice if we could rename the one "self.blocks" parameter to "self.conv_in" and "self.conv_out" as this is more extendable.
* changes per Patrik's comments * update conversion script
Congrats!! Thank you for all the great work. |
Great work - thanks a lot :-) |
* re-add RL model code * match model forward api * add register_to_config, pass training tests * fix tests, update forward outputs * remove unused code, some comments * add to docs * remove extra embedding code * unify time embedding * remove conv1d output sequential * remove sequential from conv1dblock * style and deleting duplicated code * clean files * remove unused variables * clean variables * add 1d resnet block structure for downsample * rename as unet1d * fix renaming * rename files * add get_block(...) api * unify args for model1d like model2d * minor cleaning * fix docs * improve 1d resnet blocks * fix tests, remove permuts * fix style * add output activation * rename flax blocks file * Add Value Function and corresponding example script to Diffuser implementation (huggingface#884) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review Co-authored-by: Nathan Lambert <[email protected]> * update post merge of scripts * add mdiblock / outblock architecture * Pipeline cleanup (huggingface#947) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src Co-authored-by: Nathan Lambert <[email protected]> * Update src/diffusers/models/unet_1d_blocks.py * Update tests/test_models_unet.py * RL Cleanup v2 (huggingface#965) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src * add specific vf block and update tests * style * Update tests/test_models_unet.py Co-authored-by: Nathan Lambert <[email protected]> * fix quality in tests * fix quality style, split test file * fix checks / tests * make timesteps closer to main * unify block API * unify forward api * delete lines in examples * style * examples style * all tests pass * make style * make dance_diff test pass * Refactoring RL PR (huggingface#1200) * init file changes * add import utils * finish cleaning files, imports * remove import flags * clean examples * fix imports, tests for merge * update readmes * hotfix for tests * quality * fix some tests * change defaults * more mps test fixes * unet1d defaults * do not default import experimental * defaults for tests * fix tests * fix-copies * fix * changes per Patrik's comments (huggingface#1285) * changes per Patrik's comments * update conversion script * fix renaming * skip more mps tests * last test fix * Update examples/rl/README.md Co-authored-by: Ben Glickenhaus <[email protected]>
Creating the PR to add back in the RL code when ready.
This contains the model to run inference for Janner et. al's Diffuser, with the accompanying colab.
work in progress
TODO: