Diffusers custom ChatGPT prompt
As the PyTorch expert, your role is to provide expert assistance on all things related to PyTorch, the open-source machine learning library. You are equipped to handle a wide range of queries, from basic introductory questions about PyTorch's functionality to more complex topics such as model optimization, troubleshooting, and implementation of advanced features. Your responses should always be clear, concise, and accurate, and tailored to the user's level of expertise. When faced with unclear or incomplete queries, politely request additional information so that you can provide the most helpful guidance. Keep your focus strictly on PyTorch-related topics and avoid advice on unrelated subjects. Maintain a professional yet accessible tone, simplifying complex concepts for users with different levels of understanding of PyTorch and machine learning. Recognize related and relevant code building blocks and point them out to the user where helpful. Any PyTorch model should include assertions where relevant, and any code written from scratch should include comments describing the intention of each neural network block. The attached models.txt may be used in code to subclass from, and the pipelines it contains can be used or referenced to implement diffusion pipelines. Suggest reviewing diffusers (https://github.com/huggingface/diffusers) to look at appropriate models.
When generating training code, use SLURM; refer to the train.md sample.
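A minimal sketch of the requested style (assertions plus comments describing each block; the module, names, and shapes below are illustrative only and are not taken from models.txt):

import torch
from torch import nn


class TinyDenoiser(nn.Module):
    """Illustrative denoiser: a conv encoder block followed by a conv decoder block."""

    def __init__(self, in_channels: int = 4, hidden_channels: int = 32):
        super().__init__()
        # Encoder block: lifts the noisy latent into a wider feature space.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.SiLU(),
        )
        # Decoder block: projects the features back to the latent channel count.
        self.decoder = nn.Conv2d(hidden_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        # Assert the input layout this block was designed for.
        assert sample.ndim == 4, f"expected (B, C, H, W), got {tuple(sample.shape)}"
        assert sample.shape[1] == self.encoder[0].in_channels, "unexpected channel count"
        features = self.encoder(sample)
        out = self.decoder(features)
        # The block is shape-preserving by construction.
        assert out.shape == sample.shape
        return out

Training scripts generated this way would then be launched through SLURM (sbatch), following the pattern in the referenced train.md sample.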
from ..utils import deprecate
from .transformers.transformer_temporal import (
    TransformerSpatioTemporalModel,
    TransformerTemporalModel,
    TransformerTemporalModelOutput,
)


class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
    deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModelOutput`, instead."
    deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)


class TransformerTemporalModel(TransformerTemporalModel):
    deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel`, instead."
    deprecate("TransformerTemporalModel", "0.29", deprecation_message)


class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
    deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel`, instead."
    deprecate("TransformerSpatioTemporalModel", "0.29", deprecation_message)
from dataclasses import dataclass | |
from ..utils import BaseOutput | |
@dataclass | |
class AutoencoderKLOutput(BaseOutput): | |
""" | |
Output of AutoencoderKL encoding method. | |
Args: | |
latent_dist (`DiagonalGaussianDistribution`): | |
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. | |
`DiagonalGaussianDistribution` allows for sampling latents from the distribution. | |
""" | |
latent_dist: "DiagonalGaussianDistribution" # noqa: F821 | |
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" PyTorch - Flax general utilities.""" | |
import re | |
import jax.numpy as jnp | |
from flax.traverse_util import flatten_dict, unflatten_dict | |
from jax.random import PRNGKey | |
from ..utils import logging | |
logger = logging.get_logger(__name__) | |
def rename_key(key): | |
regex = r"\w+[.]\d+" | |
pats = re.findall(regex, key) | |
for pat in pats: | |
key = key.replace(pat, "_".join(pat.split("."))) | |
return key | |
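# For example, indexed PyTorch submodule names become underscore-joined Flax-style names:
assert rename_key("down_blocks.0.attentions.1.proj_in.weight") == "down_blocks_0.attentions_1.proj_in.weight"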
##################### | |
# PyTorch => Flax # | |
##################### | |
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 | |
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py | |
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): | |
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" | |
# conv norm or layer norm | |
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) | |
# rename attention layers | |
if len(pt_tuple_key) > 1: | |
for rename_from, rename_to in ( | |
("to_out_0", "proj_attn"), | |
("to_k", "key"), | |
("to_v", "value"), | |
("to_q", "query"), | |
): | |
if pt_tuple_key[-2] == rename_from: | |
weight_name = pt_tuple_key[-1] | |
weight_name = "kernel" if weight_name == "weight" else weight_name | |
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) | |
if renamed_pt_tuple_key in random_flax_state_dict: | |
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape | |
return renamed_pt_tuple_key, pt_tensor.T | |
if ( | |
any("norm" in str_ for str_ in pt_tuple_key) | |
and (pt_tuple_key[-1] == "bias") | |
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) | |
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) | |
): | |
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) | |
return renamed_pt_tuple_key, pt_tensor | |
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: | |
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) | |
return renamed_pt_tuple_key, pt_tensor | |
    # embedding
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        return renamed_pt_tuple_key, pt_tensor
# conv layer | |
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) | |
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: | |
pt_tensor = pt_tensor.transpose(2, 3, 1, 0) | |
return renamed_pt_tuple_key, pt_tensor | |
# linear layer | |
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) | |
if pt_tuple_key[-1] == "weight": | |
pt_tensor = pt_tensor.T | |
return renamed_pt_tuple_key, pt_tensor | |
# old PyTorch layer norm weight | |
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) | |
if pt_tuple_key[-1] == "gamma": | |
return renamed_pt_tuple_key, pt_tensor | |
# old PyTorch layer norm bias | |
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) | |
if pt_tuple_key[-1] == "beta": | |
return renamed_pt_tuple_key, pt_tensor | |
return pt_tuple_key, pt_tensor | |
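# Layout conventions handled above, illustrated with dummy arrays:
_pt_conv = jnp.zeros((64, 32, 3, 3))            # PyTorch conv weight: (out, in, kH, kW)
assert _pt_conv.transpose(2, 3, 1, 0).shape == (3, 3, 32, 64)  # Flax conv kernel: (kH, kW, in, out)
_pt_linear = jnp.zeros((128, 256))              # PyTorch linear weight: (out, in)
assert _pt_linear.T.shape == (256, 128)         # Flax dense kernel: (in, out)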
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): | |
# Step 1: Convert pytorch tensor to numpy | |
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} | |
# Step 2: Since the model is stateless, get random Flax params | |
random_flax_params = flax_model.init_weights(PRNGKey(init_key)) | |
random_flax_state_dict = flatten_dict(random_flax_params) | |
flax_state_dict = {} | |
# Need to change some parameters name to match Flax names | |
for pt_key, pt_tensor in pt_state_dict.items(): | |
renamed_pt_key = rename_key(pt_key) | |
pt_tuple_key = tuple(renamed_pt_key.split(".")) | |
# Correctly rename weight parameters | |
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) | |
if flax_key in random_flax_state_dict: | |
if flax_tensor.shape != random_flax_state_dict[flax_key].shape: | |
raise ValueError( | |
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " | |
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." | |
) | |
# also add unexpected weight so that warning is thrown | |
flax_state_dict[flax_key] = jnp.asarray(flax_tensor) | |
return unflatten_dict(flax_state_dict) | |
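# Typical conversion sketch (hedged: the checkpoint id and the `from_config` call are
# assumptions; the flow is PyTorch state_dict -> renamed/reshaped Flax param tree):
#
#     from diffusers import FlaxUNet2DConditionModel, UNet2DConditionModel
#
#     pt_unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
#     flax_unet = FlaxUNet2DConditionModel.from_config(pt_unet.config)
#     flax_params = convert_pytorch_state_dict_to_flax(pt_unet.state_dict(), flax_unet)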
from typing import TYPE_CHECKING | |
from ..utils import ( | |
DIFFUSERS_SLOW_IMPORT, | |
_LazyModule, | |
is_flax_available, | |
is_torch_available, | |
) | |
_import_structure = {} | |
if is_torch_available(): | |
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] | |
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] | |
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] | |
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] | |
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] | |
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] | |
_import_structure["controlnet"] = ["ControlNetModel"] | |
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] | |
_import_structure["embeddings"] = ["ImageProjection"] | |
_import_structure["modeling_utils"] = ["ModelMixin"] | |
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"] | |
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] | |
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] | |
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] | |
_import_structure["unets.unet_1d"] = ["UNet1DModel"] | |
_import_structure["unets.unet_2d"] = ["UNet2DModel"] | |
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] | |
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"] | |
_import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"] | |
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] | |
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] | |
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] | |
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"] | |
_import_structure["unets.uvit_2d"] = ["UVit2DModel"] | |
_import_structure["vq_model"] = ["VQModel"] | |
if is_flax_available(): | |
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"] | |
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] | |
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"] | |
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: | |
if is_torch_available(): | |
from .adapter import MultiAdapter, T2IAdapter | |
from .autoencoders import ( | |
AsymmetricAutoencoderKL, | |
AutoencoderKL, | |
AutoencoderKLTemporalDecoder, | |
AutoencoderTiny, | |
ConsistencyDecoderVAE, | |
) | |
from .controlnet import ControlNetModel | |
from .embeddings import ImageProjection | |
from .modeling_utils import ModelMixin | |
from .transformers import ( | |
DualTransformer2DModel, | |
PriorTransformer, | |
T5FilmDecoder, | |
Transformer2DModel, | |
TransformerTemporalModel, | |
) | |
from .unets import ( | |
I2VGenXLUNet, | |
Kandinsky3UNet, | |
MotionAdapter, | |
StableCascadeUNet, | |
UNet1DModel, | |
UNet2DConditionModel, | |
UNet2DModel, | |
UNet3DConditionModel, | |
UNetMotionModel, | |
UNetSpatioTemporalConditionModel, | |
UVit2DModel, | |
) | |
from .vq_model import VQModel | |
if is_flax_available(): | |
from .controlnet_flax import FlaxControlNetModel | |
from .unets import FlaxUNet2DConditionModel | |
from .vae_flax import FlaxAutoencoderKL | |
else: | |
import sys | |
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) | |
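# The lazy `_import_structure` above means submodules are only imported when a name is
# first accessed; from the user's side the public API stays flat, e.g.:
#
#     from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel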
from ..utils import deprecate | |
from .transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput | |
class Transformer2DModelOutput(Transformer2DModelOutput): | |
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput`, instead." | |
deprecate("Transformer2DModelOutput", "0.29", deprecation_message) | |
class Transformer2DModel(Transformer2DModel): | |
deprecation_message = "Importing `Transformer2DModel` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModel`, instead." | |
deprecate("Transformer2DModel", "0.29", deprecation_message) | |
from dataclasses import dataclass | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from ..configuration_utils import ConfigMixin, register_to_config | |
from ..loaders import FromOriginalControlNetMixin | |
from ..utils import BaseOutput, logging | |
from .attention_processor import ( | |
ADDED_KV_ATTENTION_PROCESSORS, | |
CROSS_ATTENTION_PROCESSORS, | |
AttentionProcessor, | |
AttnAddedKVProcessor, | |
AttnProcessor, | |
) | |
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps | |
from .modeling_utils import ModelMixin | |
from .unets.unet_2d_blocks import ( | |
CrossAttnDownBlock2D, | |
DownBlock2D, | |
UNetMidBlock2D, | |
UNetMidBlock2DCrossAttn, | |
get_down_block, | |
) | |
from .unets.unet_2d_condition import UNet2DConditionModel | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
@dataclass | |
class ControlNetOutput(BaseOutput): | |
""" | |
The output of [`ControlNetModel`]. | |
Args: | |
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
""" | |
down_block_res_samples: Tuple[torch.Tensor] | |
mid_block_res_sample: torch.Tensor | |
class ControlNetConditioningEmbedding(nn.Module): | |
""" | |
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN | |
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized | |
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the | |
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides | |
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full | |
model) to encode image-space conditions ... into feature maps ..." | |
""" | |
def __init__( | |
self, | |
conditioning_embedding_channels: int, | |
conditioning_channels: int = 3, | |
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), | |
): | |
super().__init__() | |
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) | |
self.blocks = nn.ModuleList([]) | |
for i in range(len(block_out_channels) - 1): | |
channel_in = block_out_channels[i] | |
channel_out = block_out_channels[i + 1] | |
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) | |
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) | |
self.conv_out = zero_module( | |
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) | |
) | |
def forward(self, conditioning): | |
embedding = self.conv_in(conditioning) | |
embedding = F.silu(embedding) | |
for block in self.blocks: | |
embedding = block(embedding) | |
embedding = F.silu(embedding) | |
embedding = self.conv_out(embedding) | |
return embedding | |
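# Shape sketch: with the default (16, 32, 96, 256) channels and three stride-2
# convolutions, a 512x512, 3-channel conditioning image is mapped onto the 64x64 latent grid:
#
#     cond_embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320)
#     out = cond_embed(torch.randn(1, 3, 512, 512))
#     assert out.shape == (1, 320, 64, 64)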
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): | |
""" | |
A ControlNet model. | |
Args: | |
in_channels (`int`, defaults to 4): | |
The number of channels in the input sample. | |
flip_sin_to_cos (`bool`, defaults to `True`): | |
Whether to flip the sin to cos in the time embedding. | |
freq_shift (`int`, defaults to 0): | |
The frequency shift to apply to the time embedding. | |
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): | |
The tuple of downsample blocks to use. | |
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
            Whether the attention blocks should use only cross-attention (no self-attention).
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): | |
The tuple of output channels for each block. | |
layers_per_block (`int`, defaults to 2): | |
The number of layers per block. | |
downsample_padding (`int`, defaults to 1): | |
The padding to use for the downsampling convolution. | |
mid_block_scale_factor (`float`, defaults to 1): | |
The scale factor to use for the mid block. | |
act_fn (`str`, defaults to "silu"): | |
The activation function to use. | |
norm_num_groups (`int`, *optional*, defaults to 32): | |
The number of groups to use for the normalization. If None, normalization and activation layers is skipped | |
in post-processing. | |
norm_eps (`float`, defaults to 1e-5): | |
The epsilon to use for the normalization. | |
cross_attention_dim (`int`, defaults to 1280): | |
The dimension of the cross attention features. | |
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): | |
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for | |
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], | |
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
encoder_hid_dim (`int`, *optional*, defaults to None): | |
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` | |
dimension to `cross_attention_dim`. | |
encoder_hid_dim_type (`str`, *optional*, defaults to `None`): | |
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text | |
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): | |
The dimension of the attention heads. | |
        use_linear_projection (`bool`, defaults to `False`):
            Whether to use a linear projection instead of a 1x1 convolution for the transformer blocks' input and
            output projections.
class_embed_type (`str`, *optional*, defaults to `None`): | |
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, | |
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. | |
addition_embed_type (`str`, *optional*, defaults to `None`): | |
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or | |
"text". "text" will use the `TextTimeEmbedding` layer. | |
num_class_embeds (`int`, *optional*, defaults to 0): | |
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing | |
class conditioning with `class_embed_type` equal to `None`. | |
        upcast_attention (`bool`, defaults to `False`):
            Whether to upcast the attention computation to float32.
resnet_time_scale_shift (`str`, defaults to `"default"`): | |
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. | |
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): | |
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when | |
`class_embed_type="projection"`. | |
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): | |
The channel order of conditional image. Will convert to `rgb` if it's `bgr`. | |
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): | |
The tuple of output channel for each block in the `conditioning_embedding` layer. | |
global_pool_conditions (`bool`, defaults to `False`): | |
TODO(Patrick) - unused parameter. | |
addition_embed_type_num_heads (`int`, defaults to 64): | |
The number of heads to use for the `TextTimeEmbedding` layer. | |
""" | |
_supports_gradient_checkpointing = True | |
@register_to_config | |
def __init__( | |
self, | |
in_channels: int = 4, | |
conditioning_channels: int = 3, | |
flip_sin_to_cos: bool = True, | |
freq_shift: int = 0, | |
down_block_types: Tuple[str, ...] = ( | |
"CrossAttnDownBlock2D", | |
"CrossAttnDownBlock2D", | |
"CrossAttnDownBlock2D", | |
"DownBlock2D", | |
), | |
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", | |
only_cross_attention: Union[bool, Tuple[bool]] = False, | |
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), | |
layers_per_block: int = 2, | |
downsample_padding: int = 1, | |
mid_block_scale_factor: float = 1, | |
act_fn: str = "silu", | |
norm_num_groups: Optional[int] = 32, | |
norm_eps: float = 1e-5, | |
cross_attention_dim: int = 1280, | |
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, | |
encoder_hid_dim: Optional[int] = None, | |
encoder_hid_dim_type: Optional[str] = None, | |
attention_head_dim: Union[int, Tuple[int, ...]] = 8, | |
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, | |
use_linear_projection: bool = False, | |
class_embed_type: Optional[str] = None, | |
addition_embed_type: Optional[str] = None, | |
addition_time_embed_dim: Optional[int] = None, | |
num_class_embeds: Optional[int] = None, | |
upcast_attention: bool = False, | |
resnet_time_scale_shift: str = "default", | |
projection_class_embeddings_input_dim: Optional[int] = None, | |
controlnet_conditioning_channel_order: str = "rgb", | |
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), | |
global_pool_conditions: bool = False, | |
addition_embed_type_num_heads: int = 64, | |
): | |
super().__init__() | |
# If `num_attention_heads` is not defined (which is the case for most models) | |
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is. | |
# The reason for this behavior is to correct for incorrectly named variables that were introduced | |
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 | |
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking | |
# which is why we correct for the naming here. | |
num_attention_heads = num_attention_heads or attention_head_dim | |
# Check inputs | |
if len(block_out_channels) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." | |
) | |
if isinstance(transformer_layers_per_block, int): | |
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) | |
# input | |
conv_in_kernel = 3 | |
conv_in_padding = (conv_in_kernel - 1) // 2 | |
self.conv_in = nn.Conv2d( | |
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding | |
) | |
# time | |
time_embed_dim = block_out_channels[0] * 4 | |
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
timestep_input_dim = block_out_channels[0] | |
self.time_embedding = TimestepEmbedding( | |
timestep_input_dim, | |
time_embed_dim, | |
act_fn=act_fn, | |
) | |
if encoder_hid_dim_type is None and encoder_hid_dim is not None: | |
encoder_hid_dim_type = "text_proj" | |
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) | |
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") | |
if encoder_hid_dim is None and encoder_hid_dim_type is not None: | |
raise ValueError( | |
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." | |
) | |
if encoder_hid_dim_type == "text_proj": | |
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) | |
elif encoder_hid_dim_type == "text_image_proj": | |
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
            # it is set to `cross_attention_dim` here, as this is exactly the required dimension for the
            # currently only use case, `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
self.encoder_hid_proj = TextImageProjection( | |
text_embed_dim=encoder_hid_dim, | |
image_embed_dim=cross_attention_dim, | |
cross_attention_dim=cross_attention_dim, | |
) | |
elif encoder_hid_dim_type is not None: | |
raise ValueError( | |
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." | |
) | |
else: | |
self.encoder_hid_proj = None | |
# class embedding | |
if class_embed_type is None and num_class_embeds is not None: | |
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) | |
elif class_embed_type == "timestep": | |
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) | |
elif class_embed_type == "identity": | |
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) | |
elif class_embed_type == "projection": | |
if projection_class_embeddings_input_dim is None: | |
raise ValueError( | |
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" | |
) | |
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except | |
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings | |
# 2. it projects from an arbitrary input dimension. | |
# | |
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. | |
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. | |
# As a result, `TimestepEmbedding` can be passed arbitrary vectors. | |
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) | |
else: | |
self.class_embedding = None | |
if addition_embed_type == "text": | |
if encoder_hid_dim is not None: | |
text_time_embedding_from_dim = encoder_hid_dim | |
else: | |
text_time_embedding_from_dim = cross_attention_dim | |
self.add_embedding = TextTimeEmbedding( | |
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads | |
) | |
elif addition_embed_type == "text_image": | |
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
            # they are set to `cross_attention_dim` here, as this is exactly the required dimension for the
            # currently only use case, `addition_embed_type == "text_image"` (Kandinsky 2.1).
self.add_embedding = TextImageTimeEmbedding( | |
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim | |
) | |
elif addition_embed_type == "text_time": | |
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) | |
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) | |
elif addition_embed_type is not None: | |
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") | |
# control net conditioning embedding | |
self.controlnet_cond_embedding = ControlNetConditioningEmbedding( | |
conditioning_embedding_channels=block_out_channels[0], | |
block_out_channels=conditioning_embedding_out_channels, | |
conditioning_channels=conditioning_channels, | |
) | |
self.down_blocks = nn.ModuleList([]) | |
self.controlnet_down_blocks = nn.ModuleList([]) | |
if isinstance(only_cross_attention, bool): | |
only_cross_attention = [only_cross_attention] * len(down_block_types) | |
if isinstance(attention_head_dim, int): | |
attention_head_dim = (attention_head_dim,) * len(down_block_types) | |
if isinstance(num_attention_heads, int): | |
num_attention_heads = (num_attention_heads,) * len(down_block_types) | |
# down | |
output_channel = block_out_channels[0] | |
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) | |
controlnet_block = zero_module(controlnet_block) | |
self.controlnet_down_blocks.append(controlnet_block) | |
for i, down_block_type in enumerate(down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = get_down_block( | |
down_block_type, | |
num_layers=layers_per_block, | |
transformer_layers_per_block=transformer_layers_per_block[i], | |
in_channels=input_channel, | |
out_channels=output_channel, | |
temb_channels=time_embed_dim, | |
add_downsample=not is_final_block, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads[i], | |
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, | |
downsample_padding=downsample_padding, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention[i], | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
) | |
self.down_blocks.append(down_block) | |
for _ in range(layers_per_block): | |
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) | |
controlnet_block = zero_module(controlnet_block) | |
self.controlnet_down_blocks.append(controlnet_block) | |
if not is_final_block: | |
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) | |
controlnet_block = zero_module(controlnet_block) | |
self.controlnet_down_blocks.append(controlnet_block) | |
# mid | |
mid_block_channel = block_out_channels[-1] | |
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) | |
controlnet_block = zero_module(controlnet_block) | |
self.controlnet_mid_block = controlnet_block | |
if mid_block_type == "UNetMidBlock2DCrossAttn": | |
self.mid_block = UNetMidBlock2DCrossAttn( | |
transformer_layers_per_block=transformer_layers_per_block[-1], | |
in_channels=mid_block_channel, | |
temb_channels=time_embed_dim, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
output_scale_factor=mid_block_scale_factor, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads[-1], | |
resnet_groups=norm_num_groups, | |
use_linear_projection=use_linear_projection, | |
upcast_attention=upcast_attention, | |
) | |
elif mid_block_type == "UNetMidBlock2D": | |
self.mid_block = UNetMidBlock2D( | |
in_channels=block_out_channels[-1], | |
temb_channels=time_embed_dim, | |
num_layers=0, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
output_scale_factor=mid_block_scale_factor, | |
resnet_groups=norm_num_groups, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
add_attention=False, | |
) | |
else: | |
raise ValueError(f"unknown mid_block_type : {mid_block_type}") | |
@classmethod | |
def from_unet( | |
cls, | |
unet: UNet2DConditionModel, | |
controlnet_conditioning_channel_order: str = "rgb", | |
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), | |
load_weights_from_unet: bool = True, | |
conditioning_channels: int = 3, | |
): | |
r""" | |
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. | |
Parameters: | |
unet (`UNet2DConditionModel`): | |
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied | |
where applicable. | |
""" | |
transformer_layers_per_block = ( | |
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 | |
) | |
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None | |
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None | |
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None | |
addition_time_embed_dim = ( | |
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None | |
) | |
controlnet = cls( | |
encoder_hid_dim=encoder_hid_dim, | |
encoder_hid_dim_type=encoder_hid_dim_type, | |
addition_embed_type=addition_embed_type, | |
addition_time_embed_dim=addition_time_embed_dim, | |
transformer_layers_per_block=transformer_layers_per_block, | |
in_channels=unet.config.in_channels, | |
flip_sin_to_cos=unet.config.flip_sin_to_cos, | |
freq_shift=unet.config.freq_shift, | |
down_block_types=unet.config.down_block_types, | |
only_cross_attention=unet.config.only_cross_attention, | |
block_out_channels=unet.config.block_out_channels, | |
layers_per_block=unet.config.layers_per_block, | |
downsample_padding=unet.config.downsample_padding, | |
mid_block_scale_factor=unet.config.mid_block_scale_factor, | |
act_fn=unet.config.act_fn, | |
norm_num_groups=unet.config.norm_num_groups, | |
norm_eps=unet.config.norm_eps, | |
cross_attention_dim=unet.config.cross_attention_dim, | |
attention_head_dim=unet.config.attention_head_dim, | |
num_attention_heads=unet.config.num_attention_heads, | |
use_linear_projection=unet.config.use_linear_projection, | |
class_embed_type=unet.config.class_embed_type, | |
num_class_embeds=unet.config.num_class_embeds, | |
upcast_attention=unet.config.upcast_attention, | |
resnet_time_scale_shift=unet.config.resnet_time_scale_shift, | |
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, | |
mid_block_type=unet.config.mid_block_type, | |
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, | |
conditioning_embedding_out_channels=conditioning_embedding_out_channels, | |
conditioning_channels=conditioning_channels, | |
) | |
if load_weights_from_unet: | |
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) | |
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) | |
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) | |
if controlnet.class_embedding: | |
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) | |
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) | |
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) | |
return controlnet | |
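    # Typical usage sketch (the checkpoint id is illustrative):
    #
    #     unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
    #     controlnet = ControlNetModel.from_unet(unet)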
@property | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors | |
def attn_processors(self) -> Dict[str, AttentionProcessor]: | |
r""" | |
Returns: | |
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
""" | |
# set recursively | |
processors = {} | |
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): | |
if hasattr(module, "get_processor"): | |
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) | |
for sub_name, child in module.named_children(): | |
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) | |
return processors | |
for name, module in self.named_children(): | |
fn_recursive_add_processors(name, module, processors) | |
return processors | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor | |
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): | |
r""" | |
Sets the attention processor to use to compute attention. | |
Parameters: | |
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): | |
The instantiated processor class or a dictionary of processor classes that will be set as the processor | |
for **all** `Attention` layers. | |
If `processor` is a dict, the key needs to define the path to the corresponding cross attention | |
processor. This is strongly recommended when setting trainable attention processors. | |
""" | |
count = len(self.attn_processors.keys()) | |
if isinstance(processor, dict) and len(processor) != count: | |
raise ValueError( | |
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" | |
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." | |
) | |
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
if hasattr(module, "set_processor"): | |
if not isinstance(processor, dict): | |
module.set_processor(processor) | |
else: | |
module.set_processor(processor.pop(f"{name}.processor")) | |
for sub_name, child in module.named_children(): | |
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) | |
for name, module in self.named_children(): | |
fn_recursive_attn_processor(name, module, processor) | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor | |
def set_default_attn_processor(self): | |
""" | |
Disables custom attention processors and sets the default attention implementation. | |
""" | |
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnAddedKVProcessor() | |
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnProcessor() | |
else: | |
raise ValueError( | |
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" | |
) | |
self.set_attn_processor(processor) | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice | |
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: | |
r""" | |
Enable sliced attention computation. | |
When this option is enabled, the attention module splits the input tensor in slices to compute attention in | |
several steps. This is useful for saving some memory in exchange for a small decrease in speed. | |
Args: | |
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): | |
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If | |
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is | |
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` | |
must be a multiple of `slice_size`. | |
""" | |
sliceable_head_dims = [] | |
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): | |
if hasattr(module, "set_attention_slice"): | |
sliceable_head_dims.append(module.sliceable_head_dim) | |
for child in module.children(): | |
fn_recursive_retrieve_sliceable_dims(child) | |
# retrieve number of attention layers | |
for module in self.children(): | |
fn_recursive_retrieve_sliceable_dims(module) | |
num_sliceable_layers = len(sliceable_head_dims) | |
if slice_size == "auto": | |
# half the attention head size is usually a good trade-off between | |
# speed and memory | |
slice_size = [dim // 2 for dim in sliceable_head_dims] | |
elif slice_size == "max": | |
# make smallest slice possible | |
slice_size = num_sliceable_layers * [1] | |
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size | |
if len(slice_size) != len(sliceable_head_dims): | |
raise ValueError( | |
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" | |
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." | |
) | |
for i in range(len(slice_size)): | |
size = slice_size[i] | |
dim = sliceable_head_dims[i] | |
if size is not None and size > dim: | |
raise ValueError(f"size {size} has to be smaller or equal to {dim}.") | |
# Recursively walk through all the children. | |
# Any children which exposes the set_attention_slice method | |
# gets the message | |
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): | |
if hasattr(module, "set_attention_slice"): | |
module.set_attention_slice(slice_size.pop()) | |
for child in module.children(): | |
fn_recursive_set_attention_slice(child, slice_size) | |
reversed_slice_size = list(reversed(slice_size)) | |
for module in self.children(): | |
fn_recursive_set_attention_slice(module, reversed_slice_size) | |
def _set_gradient_checkpointing(self, module, value: bool = False) -> None: | |
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): | |
module.gradient_checkpointing = value | |
def forward( | |
self, | |
sample: torch.FloatTensor, | |
timestep: Union[torch.Tensor, float, int], | |
encoder_hidden_states: torch.Tensor, | |
controlnet_cond: torch.FloatTensor, | |
conditioning_scale: float = 1.0, | |
class_labels: Optional[torch.Tensor] = None, | |
timestep_cond: Optional[torch.Tensor] = None, | |
attention_mask: Optional[torch.Tensor] = None, | |
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
guess_mode: bool = False, | |
return_dict: bool = True, | |
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: | |
""" | |
The [`ControlNetModel`] forward method. | |
Args: | |
sample (`torch.FloatTensor`): | |
The noisy input tensor. | |
timestep (`Union[torch.Tensor, float, int]`): | |
The number of timesteps to denoise an input. | |
encoder_hidden_states (`torch.Tensor`): | |
The encoder hidden states. | |
controlnet_cond (`torch.FloatTensor`): | |
                The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)`.
conditioning_scale (`float`, defaults to `1.0`): | |
The scale factor for ControlNet outputs. | |
class_labels (`torch.Tensor`, *optional*, defaults to `None`): | |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. | |
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): | |
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the | |
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep | |
embeddings. | |
attention_mask (`torch.Tensor`, *optional*, defaults to `None`): | |
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask | |
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large | |
negative values to the attention scores corresponding to "discard" tokens. | |
added_cond_kwargs (`dict`): | |
Additional conditions for the Stable Diffusion XL UNet. | |
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): | |
A kwargs dictionary that if specified is passed along to the `AttnProcessor`. | |
guess_mode (`bool`, defaults to `False`): | |
                In this mode, the ControlNet encoder tries its best to recognize the content of the input even if you
                remove all prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`): | |
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. | |
Returns: | |
[`~models.controlnet.ControlNetOutput`] **or** `tuple`: | |
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is | |
returned where the first element is the sample tensor. | |
""" | |
# check channel order | |
channel_order = self.config.controlnet_conditioning_channel_order | |
if channel_order == "rgb": | |
# in rgb order by default | |
... | |
elif channel_order == "bgr": | |
controlnet_cond = torch.flip(controlnet_cond, dims=[1]) | |
else: | |
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") | |
# prepare attention_mask | |
if attention_mask is not None: | |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
attention_mask = attention_mask.unsqueeze(1) | |
# 1. time | |
timesteps = timestep | |
if not torch.is_tensor(timesteps): | |
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
# This would be a good case for the `match` statement (Python 3.10+) | |
is_mps = sample.device.type == "mps" | |
if isinstance(timestep, float): | |
dtype = torch.float32 if is_mps else torch.float64 | |
else: | |
dtype = torch.int32 if is_mps else torch.int64 | |
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
elif len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(sample.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
timesteps = timesteps.expand(sample.shape[0]) | |
t_emb = self.time_proj(timesteps) | |
# timesteps does not contain any weights and will always return f32 tensors | |
# but time_embedding might actually be running in fp16. so we need to cast here. | |
# there might be better ways to encapsulate this. | |
t_emb = t_emb.to(dtype=sample.dtype) | |
emb = self.time_embedding(t_emb, timestep_cond) | |
aug_emb = None | |
if self.class_embedding is not None: | |
if class_labels is None: | |
raise ValueError("class_labels should be provided when num_class_embeds > 0") | |
if self.config.class_embed_type == "timestep": | |
class_labels = self.time_proj(class_labels) | |
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) | |
emb = emb + class_emb | |
if self.config.addition_embed_type is not None: | |
if self.config.addition_embed_type == "text": | |
aug_emb = self.add_embedding(encoder_hidden_states) | |
elif self.config.addition_embed_type == "text_time": | |
if "text_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" | |
) | |
text_embeds = added_cond_kwargs.get("text_embeds") | |
if "time_ids" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" | |
) | |
time_ids = added_cond_kwargs.get("time_ids") | |
time_embeds = self.add_time_proj(time_ids.flatten()) | |
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) | |
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) | |
add_embeds = add_embeds.to(emb.dtype) | |
aug_emb = self.add_embedding(add_embeds) | |
emb = emb + aug_emb if aug_emb is not None else emb | |
# 2. pre-process | |
sample = self.conv_in(sample) | |
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) | |
sample = sample + controlnet_cond | |
# 3. down | |
down_block_res_samples = (sample,) | |
for downsample_block in self.down_blocks: | |
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
sample, res_samples = downsample_block( | |
hidden_states=sample, | |
temb=emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
else: | |
sample, res_samples = downsample_block(hidden_states=sample, temb=emb) | |
down_block_res_samples += res_samples | |
# 4. mid | |
if self.mid_block is not None: | |
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: | |
sample = self.mid_block( | |
sample, | |
emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
else: | |
sample = self.mid_block(sample, emb) | |
# 5. Control net blocks | |
controlnet_down_block_res_samples = () | |
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): | |
down_block_res_sample = controlnet_block(down_block_res_sample) | |
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) | |
down_block_res_samples = controlnet_down_block_res_samples | |
mid_block_res_sample = self.controlnet_mid_block(sample) | |
# 6. scaling | |
if guess_mode and not self.config.global_pool_conditions: | |
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 | |
scales = scales * conditioning_scale | |
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] | |
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one | |
else: | |
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] | |
mid_block_res_sample = mid_block_res_sample * conditioning_scale | |
if self.config.global_pool_conditions: | |
down_block_res_samples = [ | |
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples | |
] | |
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) | |
if not return_dict: | |
return (down_block_res_samples, mid_block_res_sample) | |
return ControlNetOutput( | |
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample | |
) | |
def zero_module(module): | |
for p in module.parameters(): | |
nn.init.zeros_(p) | |
return module | |
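# Zero-initializing the 1x1 ControlNet projection convolutions makes the residual branch
# a no-op at the start of training; a minimal check (sketch):
_zeroed = zero_module(nn.Conv2d(8, 8, kernel_size=1))
with torch.no_grad():
    assert torch.all(_zeroed(torch.randn(2, 8, 16, 16)) == 0)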
import math | |
import flax.linen as nn | |
import jax.numpy as jnp | |
def get_sinusoidal_embeddings( | |
timesteps: jnp.ndarray, | |
embedding_dim: int, | |
freq_shift: float = 1, | |
min_timescale: float = 1, | |
max_timescale: float = 1.0e4, | |
flip_sin_to_cos: bool = False, | |
scale: float = 1.0, | |
) -> jnp.ndarray: | |
"""Returns the positional encoding (same as Tensor2Tensor). | |
Args: | |
timesteps: a 1-D Tensor of N indices, one per batch element. | |
These may be fractional. | |
embedding_dim: The number of output channels. | |
min_timescale: The smallest time unit (should probably be 0.0). | |
max_timescale: The largest time unit. | |
Returns: | |
a Tensor of timing signals [N, num_channels] | |
""" | |
assert timesteps.ndim == 1, "Timesteps should be a 1d-array" | |
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" | |
num_timescales = float(embedding_dim // 2) | |
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) | |
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) | |
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) | |
# scale embeddings | |
scaled_time = scale * emb | |
if flip_sin_to_cos: | |
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) | |
else: | |
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) | |
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) | |
return signal | |
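# Shape sketch: for a 1-D batch of timesteps the signal has shape (batch, embedding_dim),
# half sine and half cosine (order controlled by `flip_sin_to_cos`):
_example = get_sinusoidal_embeddings(jnp.array([0.0, 10.0, 999.0]), embedding_dim=32)
assert _example.shape == (3, 32)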
class FlaxTimestepEmbedding(nn.Module): | |
r""" | |
Time step Embedding Module. Learns embeddings for input time steps. | |
Args: | |
time_embed_dim (`int`, *optional*, defaults to `32`): | |
Time step embedding dimension | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
time_embed_dim: int = 32 | |
dtype: jnp.dtype = jnp.float32 | |
@nn.compact | |
def __call__(self, temb): | |
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) | |
temb = nn.silu(temb) | |
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) | |
return temb | |
class FlaxTimesteps(nn.Module): | |
r""" | |
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 | |
Args: | |
dim (`int`, *optional*, defaults to `32`): | |
Time step embedding dimension | |
""" | |
dim: int = 32 | |
flip_sin_to_cos: bool = False | |
freq_shift: float = 1 | |
@nn.compact | |
def __call__(self, timesteps): | |
return get_sinusoidal_embeddings( | |
timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift | |
) | |
from dataclasses import dataclass | |
from typing import Any, Dict, Optional | |
import torch | |
from torch import nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...utils import BaseOutput | |
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock | |
from ..embeddings import TimestepEmbedding, Timesteps | |
from ..modeling_utils import ModelMixin | |
from ..resnet import AlphaBlender | |
@dataclass | |
class TransformerTemporalModelOutput(BaseOutput): | |
""" | |
The output of [`TransformerTemporalModel`]. | |
Args: | |
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): | |
The hidden states output conditioned on `encoder_hidden_states` input. | |
""" | |
sample: torch.FloatTensor | |
class TransformerTemporalModel(ModelMixin, ConfigMixin): | |
""" | |
A Transformer model for video-like data. | |
Parameters: | |
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. | |
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. | |
in_channels (`int`, *optional*): | |
The number of channels in the input and output (specify if the input is **continuous**). | |
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. | |
attention_bias (`bool`, *optional*): | |
Configure if the `TransformerBlock` attention should contain a bias parameter. | |
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). | |
This is fixed during training since it is used to learn a number of position embeddings. | |
activation_fn (`str`, *optional*, defaults to `"geglu"`): | |
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported | |
activation functions. | |
norm_elementwise_affine (`bool`, *optional*): | |
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. | |
double_self_attention (`bool`, *optional*): | |
Configure if each `TransformerBlock` should contain two self-attention layers. | |
        positional_embeddings (`str`, *optional*):
            The type of positional embeddings to apply to the sequence input before use.
        num_positional_embeddings (`int`, *optional*):
            The maximum length of the sequence over which to apply positional embeddings.
""" | |
@register_to_config | |
def __init__( | |
self, | |
num_attention_heads: int = 16, | |
attention_head_dim: int = 88, | |
in_channels: Optional[int] = None, | |
out_channels: Optional[int] = None, | |
num_layers: int = 1, | |
dropout: float = 0.0, | |
norm_num_groups: int = 32, | |
cross_attention_dim: Optional[int] = None, | |
attention_bias: bool = False, | |
sample_size: Optional[int] = None, | |
activation_fn: str = "geglu", | |
norm_elementwise_affine: bool = True, | |
double_self_attention: bool = True, | |
positional_embeddings: Optional[str] = None, | |
num_positional_embeddings: Optional[int] = None, | |
): | |
super().__init__() | |
self.num_attention_heads = num_attention_heads | |
self.attention_head_dim = attention_head_dim | |
inner_dim = num_attention_heads * attention_head_dim | |
self.in_channels = in_channels | |
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) | |
self.proj_in = nn.Linear(in_channels, inner_dim) | |
# 3. Define transformers blocks | |
self.transformer_blocks = nn.ModuleList( | |
[ | |
BasicTransformerBlock( | |
inner_dim, | |
num_attention_heads, | |
attention_head_dim, | |
dropout=dropout, | |
cross_attention_dim=cross_attention_dim, | |
activation_fn=activation_fn, | |
attention_bias=attention_bias, | |
double_self_attention=double_self_attention, | |
norm_elementwise_affine=norm_elementwise_affine, | |
positional_embeddings=positional_embeddings, | |
num_positional_embeddings=num_positional_embeddings, | |
) | |
for d in range(num_layers) | |
] | |
) | |
self.proj_out = nn.Linear(inner_dim, in_channels) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.LongTensor] = None, | |
timestep: Optional[torch.LongTensor] = None, | |
class_labels: torch.LongTensor = None, | |
num_frames: int = 1, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
return_dict: bool = True, | |
) -> TransformerTemporalModelOutput: | |
""" | |
The [`TransformerTemporal`] forward method. | |
Args: | |
hidden_states (`torch.FloatTensor` of shape `(batch_size * num_frames, channel, height, width)`): | 
Input hidden_states. | |
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): | |
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to | |
self-attention. | |
timestep ( `torch.LongTensor`, *optional*): | |
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. | |
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): | |
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in | |
`AdaLayerNormZero`. | 
num_frames (`int`, *optional*, defaults to 1): | |
The number of frames to be processed per batch. This is used to reshape the hidden states. | |
cross_attention_kwargs (`dict`, *optional*): | |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
`self.processor` in | |
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain | 
tuple. | |
Returns: | |
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: | |
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is | |
returned, otherwise a `tuple` where the first element is the sample tensor. | |
""" | |
# 1. Input | |
batch_frames, channel, height, width = hidden_states.shape | |
batch_size = batch_frames // num_frames | |
residual = hidden_states | |
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) | |
hidden_states = hidden_states.permute(0, 2, 1, 3, 4) | |
hidden_states = self.norm(hidden_states) | |
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) | |
hidden_states = self.proj_in(hidden_states) | |
# 2. Blocks | |
for block in self.transformer_blocks: | |
hidden_states = block( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
timestep=timestep, | |
cross_attention_kwargs=cross_attention_kwargs, | |
class_labels=class_labels, | |
) | |
# 3. Output | |
hidden_states = self.proj_out(hidden_states) | |
hidden_states = ( | |
hidden_states[None, None, :] | |
.reshape(batch_size, height, width, num_frames, channel) | |
.permute(0, 3, 4, 1, 2) | |
.contiguous() | |
) | |
hidden_states = hidden_states.reshape(batch_frames, channel, height, width) | |
output = hidden_states + residual | |
if not return_dict: | |
return (output,) | |
return TransformerTemporalModelOutput(sample=output) | |
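# Illustrative usage sketch (not part of the diffusers source); `_example_transformer_temporal_usage`
# is a hypothetical helper added for illustration only. The small dimensions are arbitrary
# assumptions chosen so the block runs quickly; released checkpoints use much larger settings.
def _example_transformer_temporal_usage():
    # Temporal transformer: attends across frames at every spatial location.
    model = TransformerTemporalModel(
        num_attention_heads=2,
        attention_head_dim=8,      # inner_dim = 2 * 8 = 16
        in_channels=32,            # must be divisible by norm_num_groups
        norm_num_groups=8,
        num_layers=1,
    )
    batch_size, num_frames = 2, 4
    # Input is flattened over frames: (batch_size * num_frames, C, H, W).
    hidden_states = torch.randn(batch_size * num_frames, 32, 8, 8)
    out = model(hidden_states, num_frames=num_frames).sample
    assert out.shape == hidden_states.shape  # residual connection keeps the shape
    return out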
class TransformerSpatioTemporalModel(nn.Module): | |
""" | |
A Transformer model for video-like data. | |
Parameters: | |
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. | |
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. | |
in_channels (`int`, *optional*): | |
The number of channels in the input and output (specify if the input is **continuous**). | |
out_channels (`int`, *optional*): | |
The number of channels in the output (specify if the input is **continuous**). | |
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. | |
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. | |
""" | |
def __init__( | |
self, | |
num_attention_heads: int = 16, | |
attention_head_dim: int = 88, | |
in_channels: int = 320, | |
out_channels: Optional[int] = None, | |
num_layers: int = 1, | |
cross_attention_dim: Optional[int] = None, | |
): | |
super().__init__() | |
self.num_attention_heads = num_attention_heads | |
self.attention_head_dim = attention_head_dim | |
inner_dim = num_attention_heads * attention_head_dim | |
self.inner_dim = inner_dim | |
# 2. Define input layers | |
self.in_channels = in_channels | |
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) | |
self.proj_in = nn.Linear(in_channels, inner_dim) | |
# 3. Define transformers blocks | |
self.transformer_blocks = nn.ModuleList( | |
[ | |
BasicTransformerBlock( | |
inner_dim, | |
num_attention_heads, | |
attention_head_dim, | |
cross_attention_dim=cross_attention_dim, | |
) | |
for d in range(num_layers) | |
] | |
) | |
time_mix_inner_dim = inner_dim | |
self.temporal_transformer_blocks = nn.ModuleList( | |
[ | |
TemporalBasicTransformerBlock( | |
inner_dim, | |
time_mix_inner_dim, | |
num_attention_heads, | |
attention_head_dim, | |
cross_attention_dim=cross_attention_dim, | |
) | |
for _ in range(num_layers) | |
] | |
) | |
time_embed_dim = in_channels * 4 | |
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) | |
self.time_proj = Timesteps(in_channels, True, 0) | |
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") | |
# 4. Define output layers | |
self.out_channels = in_channels if out_channels is None else out_channels | |
# TODO: should use out_channels for continuous projections | |
self.proj_out = nn.Linear(inner_dim, in_channels) | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
encoder_hidden_states: Optional[torch.Tensor] = None, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
return_dict: bool = True, | |
): | |
""" | |
Args: | |
hidden_states (`torch.FloatTensor` of shape `(batch_size * num_frames, channel, height, width)`): | 
Input hidden_states. The number of frames per batch is inferred from `image_only_indicator` and is | 
used to reshape the hidden states. | 
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): | |
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to | |
self-attention. | |
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): | |
A tensor indicating whether the input contains only images. 1 indicates that the input contains only | |
images, 0 indicates that the input contains video frames. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain | |
tuple. | |
Returns: | |
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: | |
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is | |
returned, otherwise a `tuple` where the first element is the sample tensor. | |
""" | |
# 1. Input | |
batch_frames, _, height, width = hidden_states.shape | |
num_frames = image_only_indicator.shape[-1] | |
batch_size = batch_frames // num_frames | |
time_context = encoder_hidden_states | |
time_context_first_timestep = time_context[None, :].reshape( | |
batch_size, num_frames, -1, time_context.shape[-1] | |
)[:, 0] | |
time_context = time_context_first_timestep[None, :].broadcast_to( | |
height * width, batch_size, 1, time_context.shape[-1] | |
) | |
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) | |
residual = hidden_states | |
hidden_states = self.norm(hidden_states) | |
inner_dim = hidden_states.shape[1] | |
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) | |
hidden_states = self.proj_in(hidden_states) | |
num_frames_emb = torch.arange(num_frames, device=hidden_states.device) | |
num_frames_emb = num_frames_emb.repeat(batch_size, 1) | |
num_frames_emb = num_frames_emb.reshape(-1) | |
t_emb = self.time_proj(num_frames_emb) | |
# `Timesteps` does not contain any weights and will always return f32 tensors | |
# but time_embedding might actually be running in fp16, so we need to cast here. | 
# there might be better ways to encapsulate this. | |
t_emb = t_emb.to(dtype=hidden_states.dtype) | |
emb = self.time_pos_embed(t_emb) | |
emb = emb[:, None, :] | |
# 2. Blocks | |
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): | |
if self.training and self.gradient_checkpointing: | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
block, | |
hidden_states, | |
None, | |
encoder_hidden_states, | |
None, | |
use_reentrant=False, | |
) | |
else: | |
hidden_states = block( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
) | |
hidden_states_mix = hidden_states | |
hidden_states_mix = hidden_states_mix + emb | |
hidden_states_mix = temporal_block( | |
hidden_states_mix, | |
num_frames=num_frames, | |
encoder_hidden_states=time_context, | |
) | |
hidden_states = self.time_mixer( | |
x_spatial=hidden_states, | |
x_temporal=hidden_states_mix, | |
image_only_indicator=image_only_indicator, | |
) | |
# 3. Output | |
hidden_states = self.proj_out(hidden_states) | |
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() | |
output = hidden_states + residual | |
if not return_dict: | |
return (output,) | |
return TransformerTemporalModelOutput(sample=output) | |
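# Illustrative usage sketch (not part of the diffusers source); `_example_transformer_spatio_temporal_usage`
# is a hypothetical helper added for illustration only, with arbitrary small dimensions. Two constraints
# visible in the code above: `num_frames` is inferred from `image_only_indicator`, and the learned frame
# embedding (`time_pos_embed`, out_dim=in_channels) is added to features of width inner_dim, so this
# example keeps inner_dim == in_channels.
def _example_transformer_spatio_temporal_usage():
    model = TransformerSpatioTemporalModel(
        num_attention_heads=2,
        attention_head_dim=16,     # inner_dim = 2 * 16 = 32 == in_channels
        in_channels=32,            # GroupNorm above uses 32 groups
        num_layers=1,
        cross_attention_dim=16,
    )
    batch_size, num_frames = 1, 4
    hidden_states = torch.randn(batch_size * num_frames, 32, 8, 8)
    # Per-frame conditioning; the first frame's context is reused as `time_context` for the temporal blocks.
    encoder_hidden_states = torch.randn(batch_size * num_frames, 1, 16)
    image_only_indicator = torch.zeros(batch_size, num_frames)  # 0 = video frame, 1 = image-only sample
    out = model(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
    ).sample
    assert out.shape == hidden_states.shape  # residual connection keeps the shape
    return out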
from ...utils import is_torch_available | |
if is_torch_available(): | |
from .dual_transformer_2d import DualTransformer2DModel | |
from .prior_transformer import PriorTransformer | |
from .t5_film_transformer import T5FilmDecoder | |
from .transformer_2d import Transformer2DModel | |
from .transformer_temporal import TransformerTemporalModel | |
from dataclasses import dataclass | |
from typing import Any, Dict, Optional | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version | |
from ..attention import BasicTransformerBlock | |
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection | |
from ..lora import LoRACompatibleConv, LoRACompatibleLinear | |
from ..modeling_utils import ModelMixin | |
from ..normalization import AdaLayerNormSingle | |
@dataclass | |
class Transformer2DModelOutput(BaseOutput): | |
""" | |
The output of [`Transformer2DModel`]. | |
Args: | |
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): | |
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability | |
distributions for the unnoised latent pixels. | |
""" | |
sample: torch.FloatTensor | |
class Transformer2DModel(ModelMixin, ConfigMixin): | |
""" | |
A 2D Transformer model for image-like data. | |
Parameters: | |
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. | |
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. | |
in_channels (`int`, *optional*): | |
The number of channels in the input and output (specify if the input is **continuous**). | |
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. | |
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). | |
This is fixed during training since it is used to learn a number of position embeddings. | |
num_vector_embeds (`int`, *optional*): | |
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). | |
Includes the class for the masked latent pixel. | |
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. | |
num_embeds_ada_norm ( `int`, *optional*): | |
The number of diffusion steps used during training. Pass if at least one of the norm_layers is | |
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are | |
added to the hidden states. | |
During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps. | 
attention_bias (`bool`, *optional*): | |
Configure if the `TransformerBlocks` attention should contain a bias parameter. | |
""" | |
_supports_gradient_checkpointing = True | |
@register_to_config | |
def __init__( | |
self, | |
num_attention_heads: int = 16, | |
attention_head_dim: int = 88, | |
in_channels: Optional[int] = None, | |
out_channels: Optional[int] = None, | |
num_layers: int = 1, | |
dropout: float = 0.0, | |
norm_num_groups: int = 32, | |
cross_attention_dim: Optional[int] = None, | |
attention_bias: bool = False, | |
sample_size: Optional[int] = None, | |
num_vector_embeds: Optional[int] = None, | |
patch_size: Optional[int] = None, | |
activation_fn: str = "geglu", | |
num_embeds_ada_norm: Optional[int] = None, | |
use_linear_projection: bool = False, | |
only_cross_attention: bool = False, | |
double_self_attention: bool = False, | |
upcast_attention: bool = False, | |
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' | |
norm_elementwise_affine: bool = True, | |
norm_eps: float = 1e-5, | |
attention_type: str = "default", | |
caption_channels: int = None, | |
interpolation_scale: float = None, | |
): | |
super().__init__() | |
if patch_size is not None: | |
if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: | |
raise NotImplementedError( | |
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." | |
) | |
elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: | |
raise ValueError( | |
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." | |
) | |
self.use_linear_projection = use_linear_projection | |
self.num_attention_heads = num_attention_heads | |
self.attention_head_dim = attention_head_dim | |
inner_dim = num_attention_heads * attention_head_dim | |
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv | |
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear | |
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` | |
# Define whether input is continuous or discrete depending on configuration | |
self.is_input_continuous = (in_channels is not None) and (patch_size is None) | |
self.is_input_vectorized = num_vector_embeds is not None | |
self.is_input_patches = in_channels is not None and patch_size is not None | |
if norm_type == "layer_norm" and num_embeds_ada_norm is not None: | |
deprecation_message = ( | |
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" | |
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." | |
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" | |
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" | |
" would be very nice if you could open a Pull request for the `transformer/config.json` file" | |
) | |
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) | |
norm_type = "ada_norm" | |
if self.is_input_continuous and self.is_input_vectorized: | |
raise ValueError( | |
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" | |
" sure that either `in_channels` or `num_vector_embeds` is None." | |
) | |
elif self.is_input_vectorized and self.is_input_patches: | |
raise ValueError( | |
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" | |
" sure that either `num_vector_embeds` or `num_patches` is None." | |
) | |
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: | |
raise ValueError( | |
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" | |
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." | |
) | |
# 2. Define input layers | |
if self.is_input_continuous: | |
self.in_channels = in_channels | |
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) | |
if use_linear_projection: | |
self.proj_in = linear_cls(in_channels, inner_dim) | |
else: | |
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) | |
elif self.is_input_vectorized: | |
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" | |
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" | |
self.height = sample_size | |
self.width = sample_size | |
self.num_vector_embeds = num_vector_embeds | |
self.num_latent_pixels = self.height * self.width | |
self.latent_image_embedding = ImagePositionalEmbeddings( | |
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width | |
) | |
elif self.is_input_patches: | |
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" | |
self.height = sample_size | |
self.width = sample_size | |
self.patch_size = patch_size | |
interpolation_scale = ( | |
interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1) | |
) | |
self.pos_embed = PatchEmbed( | |
height=sample_size, | |
width=sample_size, | |
patch_size=patch_size, | |
in_channels=in_channels, | |
embed_dim=inner_dim, | |
interpolation_scale=interpolation_scale, | |
) | |
# 3. Define transformers blocks | |
self.transformer_blocks = nn.ModuleList( | |
[ | |
BasicTransformerBlock( | |
inner_dim, | |
num_attention_heads, | |
attention_head_dim, | |
dropout=dropout, | |
cross_attention_dim=cross_attention_dim, | |
activation_fn=activation_fn, | |
num_embeds_ada_norm=num_embeds_ada_norm, | |
attention_bias=attention_bias, | |
only_cross_attention=only_cross_attention, | |
double_self_attention=double_self_attention, | |
upcast_attention=upcast_attention, | |
norm_type=norm_type, | |
norm_elementwise_affine=norm_elementwise_affine, | |
norm_eps=norm_eps, | |
attention_type=attention_type, | |
) | |
for d in range(num_layers) | |
] | |
) | |
# 4. Define output layers | |
self.out_channels = in_channels if out_channels is None else out_channels | |
if self.is_input_continuous: | |
# TODO: should use out_channels for continuous projections | |
if use_linear_projection: | |
self.proj_out = linear_cls(inner_dim, in_channels) | |
else: | |
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) | |
elif self.is_input_vectorized: | |
self.norm_out = nn.LayerNorm(inner_dim) | |
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) | |
elif self.is_input_patches and norm_type != "ada_norm_single": | |
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) | |
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) | |
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) | |
elif self.is_input_patches and norm_type == "ada_norm_single": | |
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) | |
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) | |
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) | |
# 5. PixArt-Alpha blocks. | |
self.adaln_single = None | |
self.use_additional_conditions = False | |
if norm_type == "ada_norm_single": | |
self.use_additional_conditions = self.config.sample_size == 128 | |
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use | |
# additional conditions until we find better name | |
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) | |
self.caption_projection = None | |
if caption_channels is not None: | |
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) | |
self.gradient_checkpointing = False | |
def _set_gradient_checkpointing(self, module, value=False): | |
if hasattr(module, "gradient_checkpointing"): | |
module.gradient_checkpointing = value | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
encoder_hidden_states: Optional[torch.Tensor] = None, | |
timestep: Optional[torch.LongTensor] = None, | |
added_cond_kwargs: Dict[str, torch.Tensor] = None, | |
class_labels: Optional[torch.LongTensor] = None, | |
cross_attention_kwargs: Dict[str, Any] = None, | |
attention_mask: Optional[torch.Tensor] = None, | |
encoder_attention_mask: Optional[torch.Tensor] = None, | |
return_dict: bool = True, | |
): | |
""" | |
The [`Transformer2DModel`] forward method. | |
Args: | |
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): | |
Input `hidden_states`. | |
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): | |
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to | |
self-attention. | |
timestep ( `torch.LongTensor`, *optional*): | |
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. | |
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): | |
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in | |
`AdaLayerNormZero`. | 
cross_attention_kwargs ( `Dict[str, Any]`, *optional*): | |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
`self.processor` in | |
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
attention_mask ( `torch.Tensor`, *optional*): | |
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask | |
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large | |
negative values to the attention scores corresponding to "discard" tokens. | |
encoder_attention_mask ( `torch.Tensor`, *optional*): | |
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: | |
* Mask `(batch, sequence_length)` True = keep, False = discard. | |
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. | |
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format | |
above. This bias will be added to the cross-attention scores. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain | 
tuple. | |
Returns: | |
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a | |
`tuple` where the first element is the sample tensor. | |
""" | |
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension. | |
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. | |
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. | |
# expects mask of shape: | |
# [batch, key_tokens] | |
# adds singleton query_tokens dimension: | |
# [batch, 1, key_tokens] | |
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: | |
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) | |
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) | |
if attention_mask is not None and attention_mask.ndim == 2: | |
# assume that mask is expressed as: | |
# (1 = keep, 0 = discard) | |
# convert mask into a bias that can be added to attention scores: | |
# (keep = +0, discard = -10000.0) | |
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 | |
attention_mask = attention_mask.unsqueeze(1) | |
# convert encoder_attention_mask to a bias the same way we do for attention_mask | |
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: | |
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 | |
encoder_attention_mask = encoder_attention_mask.unsqueeze(1) | |
# Retrieve lora scale. | |
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 | |
# 1. Input | |
if self.is_input_continuous: | |
batch, _, height, width = hidden_states.shape | |
residual = hidden_states | |
hidden_states = self.norm(hidden_states) | |
if not self.use_linear_projection: | |
hidden_states = ( | |
self.proj_in(hidden_states, scale=lora_scale) | |
if not USE_PEFT_BACKEND | |
else self.proj_in(hidden_states) | |
) | |
inner_dim = hidden_states.shape[1] | |
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) | |
else: | |
inner_dim = hidden_states.shape[1] | |
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) | |
hidden_states = ( | |
self.proj_in(hidden_states, scale=lora_scale) | |
if not USE_PEFT_BACKEND | |
else self.proj_in(hidden_states) | |
) | |
elif self.is_input_vectorized: | |
hidden_states = self.latent_image_embedding(hidden_states) | |
elif self.is_input_patches: | |
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size | |
hidden_states = self.pos_embed(hidden_states) | |
if self.adaln_single is not None: | |
if self.use_additional_conditions and added_cond_kwargs is None: | |
raise ValueError( | |
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." | |
) | |
batch_size = hidden_states.shape[0] | |
timestep, embedded_timestep = self.adaln_single( | |
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype | |
) | |
# 2. Blocks | |
if self.caption_projection is not None: | |
batch_size = hidden_states.shape[0] | |
encoder_hidden_states = self.caption_projection(encoder_hidden_states) | |
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) | |
for block in self.transformer_blocks: | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module, return_dict=None): | |
def custom_forward(*inputs): | |
if return_dict is not None: | |
return module(*inputs, return_dict=return_dict) | |
else: | |
return module(*inputs) | |
return custom_forward | |
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(block), | |
hidden_states, | |
attention_mask, | |
encoder_hidden_states, | |
encoder_attention_mask, | |
timestep, | |
cross_attention_kwargs, | |
class_labels, | |
**ckpt_kwargs, | |
) | |
else: | |
hidden_states = block( | |
hidden_states, | |
attention_mask=attention_mask, | |
encoder_hidden_states=encoder_hidden_states, | |
encoder_attention_mask=encoder_attention_mask, | |
timestep=timestep, | |
cross_attention_kwargs=cross_attention_kwargs, | |
class_labels=class_labels, | |
) | |
# 3. Output | |
if self.is_input_continuous: | |
if not self.use_linear_projection: | |
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() | |
hidden_states = ( | |
self.proj_out(hidden_states, scale=lora_scale) | |
if not USE_PEFT_BACKEND | |
else self.proj_out(hidden_states) | |
) | |
else: | |
hidden_states = ( | |
self.proj_out(hidden_states, scale=lora_scale) | |
if not USE_PEFT_BACKEND | |
else self.proj_out(hidden_states) | |
) | |
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() | |
output = hidden_states + residual | |
elif self.is_input_vectorized: | |
hidden_states = self.norm_out(hidden_states) | |
logits = self.out(hidden_states) | |
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels) | |
logits = logits.permute(0, 2, 1) | |
# log(p(x_0)) | |
output = F.log_softmax(logits.double(), dim=1).float() | |
if self.is_input_patches: | |
if self.config.norm_type != "ada_norm_single": | |
conditioning = self.transformer_blocks[0].norm1.emb( | |
timestep, class_labels, hidden_dtype=hidden_states.dtype | |
) | |
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) | |
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] | |
hidden_states = self.proj_out_2(hidden_states) | |
elif self.config.norm_type == "ada_norm_single": | |
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) | |
hidden_states = self.norm_out(hidden_states) | |
# Modulation | |
hidden_states = hidden_states * (1 + scale) + shift | |
hidden_states = self.proj_out(hidden_states) | |
hidden_states = hidden_states.squeeze(1) | |
# unpatchify | |
if self.adaln_single is None: | |
height = width = int(hidden_states.shape[1] ** 0.5) | |
hidden_states = hidden_states.reshape( | |
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) | |
) | |
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) | |
output = hidden_states.reshape( | |
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) | |
) | |
if not return_dict: | |
return (output,) | |
return Transformer2DModelOutput(sample=output) | |
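# Illustrative usage sketch (not part of the diffusers source); `_example_transformer_2d_usage`
# is a hypothetical helper added for illustration only. It exercises the **continuous** input
# path (`in_channels` set, `patch_size` left as None); the dimensions are arbitrary assumptions.
def _example_transformer_2d_usage():
    model = Transformer2DModel(
        num_attention_heads=2,
        attention_head_dim=8,      # inner_dim = 2 * 8 = 16
        in_channels=32,            # must be divisible by norm_num_groups (32 by default)
        num_layers=1,
        cross_attention_dim=16,
        use_linear_projection=True,
    )
    hidden_states = torch.randn(2, 32, 8, 8)        # (batch, C, H, W) latent feature map
    encoder_hidden_states = torch.randn(2, 7, 16)   # (batch, seq, cross_attention_dim) conditioning
    out = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
    assert out.shape == hidden_states.shape         # residual connection keeps the shape
    return out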
from dataclasses import dataclass | |
from typing import Dict, Optional, Union | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin | |
from ...utils import BaseOutput | |
from ..attention import BasicTransformerBlock | |
from ..attention_processor import ( | |
ADDED_KV_ATTENTION_PROCESSORS, | |
CROSS_ATTENTION_PROCESSORS, | |
AttentionProcessor, | |
AttnAddedKVProcessor, | |
AttnProcessor, | |
) | |
from ..embeddings import TimestepEmbedding, Timesteps | |
from ..modeling_utils import ModelMixin | |
@dataclass | |
class PriorTransformerOutput(BaseOutput): | |
""" | |
The output of [`PriorTransformer`]. | |
Args: | |
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): | |
The predicted CLIP image embedding conditioned on the CLIP text embedding input. | |
""" | |
predicted_image_embedding: torch.FloatTensor | |
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): | |
""" | |
A Prior Transformer model. | |
Parameters: | |
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. | |
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. | |
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. | |
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` | |
num_embeddings (`int`, *optional*, defaults to 77): | |
The number of embeddings of the model input `hidden_states` | |
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the | |
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + | |
additional_embeddings`. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
time_embed_act_fn (`str`, *optional*, defaults to 'silu'): | |
The activation function to use to create timestep embeddings. | |
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before | |
passing to Transformer blocks. Set it to `None` if normalization is not needed. | |
embedding_proj_norm_type (`str`, *optional*, defaults to None): | |
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not | |
needed. | |
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): | |
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if | |
`encoder_hidden_states` is `None`. | |
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. | |
Choose from `prd` or `None`. If `prd` is chosen, a token indicating the (quantized) dot product between | 
the text embedding and image embedding, as proposed in the unCLIP paper | 
https://arxiv.org/abs/2204.06125, will be prepended. If it is `None`, no additional embeddings will be prepended. | 
time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings. | 
If None, will be set to `num_attention_heads * attention_head_dim` | |
embedding_proj_dim (`int`, *optional*, default to None): | |
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. | |
clip_embed_dim (`int`, *optional*, default to None): | |
The dimension of the output. If None, will be set to `embedding_dim`. | |
""" | |
@register_to_config | |
def __init__( | |
self, | |
num_attention_heads: int = 32, | |
attention_head_dim: int = 64, | |
num_layers: int = 20, | |
embedding_dim: int = 768, | |
num_embeddings=77, | |
additional_embeddings=4, | |
dropout: float = 0.0, | |
time_embed_act_fn: str = "silu", | |
norm_in_type: Optional[str] = None, # layer | |
embedding_proj_norm_type: Optional[str] = None, # layer | |
encoder_hid_proj_type: Optional[str] = "linear", # linear | |
added_emb_type: Optional[str] = "prd", # prd | |
time_embed_dim: Optional[int] = None, | |
embedding_proj_dim: Optional[int] = None, | |
clip_embed_dim: Optional[int] = None, | |
): | |
super().__init__() | |
self.num_attention_heads = num_attention_heads | |
self.attention_head_dim = attention_head_dim | |
inner_dim = num_attention_heads * attention_head_dim | |
self.additional_embeddings = additional_embeddings | |
time_embed_dim = time_embed_dim or inner_dim | |
embedding_proj_dim = embedding_proj_dim or embedding_dim | |
clip_embed_dim = clip_embed_dim or embedding_dim | |
self.time_proj = Timesteps(inner_dim, True, 0) | |
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) | |
self.proj_in = nn.Linear(embedding_dim, inner_dim) | |
if embedding_proj_norm_type is None: | |
self.embedding_proj_norm = None | |
elif embedding_proj_norm_type == "layer": | |
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) | |
else: | |
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") | |
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) | |
if encoder_hid_proj_type is None: | |
self.encoder_hidden_states_proj = None | |
elif encoder_hid_proj_type == "linear": | |
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) | |
else: | |
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") | |
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) | |
if added_emb_type == "prd": | |
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) | |
elif added_emb_type is None: | |
self.prd_embedding = None | |
else: | |
raise ValueError( | |
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." | |
) | |
self.transformer_blocks = nn.ModuleList( | |
[ | |
BasicTransformerBlock( | |
inner_dim, | |
num_attention_heads, | |
attention_head_dim, | |
dropout=dropout, | |
activation_fn="gelu", | |
attention_bias=True, | |
) | |
for d in range(num_layers) | |
] | |
) | |
if norm_in_type == "layer": | |
self.norm_in = nn.LayerNorm(inner_dim) | |
elif norm_in_type is None: | |
self.norm_in = None | |
else: | |
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") | |
self.norm_out = nn.LayerNorm(inner_dim) | |
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) | |
causal_attention_mask = torch.full( | |
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 | |
) | |
causal_attention_mask.triu_(1) | |
causal_attention_mask = causal_attention_mask[None, ...] | |
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) | |
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) | |
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) | |
@property | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors | |
def attn_processors(self) -> Dict[str, AttentionProcessor]: | |
r""" | |
Returns: | |
`dict` of attention processors: A dictionary containing all attention processors used in the model, | 
indexed by their weight names. | 
""" | |
# set recursively | |
processors = {} | |
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): | |
if hasattr(module, "get_processor"): | |
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) | |
for sub_name, child in module.named_children(): | |
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) | |
return processors | |
for name, module in self.named_children(): | |
fn_recursive_add_processors(name, module, processors) | |
return processors | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor | |
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): | |
r""" | |
Sets the attention processor to use to compute attention. | |
Parameters: | |
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): | |
The instantiated processor class or a dictionary of processor classes that will be set as the processor | |
for **all** `Attention` layers. | |
If `processor` is a dict, the key needs to define the path to the corresponding cross attention | |
processor. This is strongly recommended when setting trainable attention processors. | |
""" | |
count = len(self.attn_processors.keys()) | |
if isinstance(processor, dict) and len(processor) != count: | |
raise ValueError( | |
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" | |
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." | |
) | |
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
if hasattr(module, "set_processor"): | |
if not isinstance(processor, dict): | |
module.set_processor(processor) | |
else: | |
module.set_processor(processor.pop(f"{name}.processor")) | |
for sub_name, child in module.named_children(): | |
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) | |
for name, module in self.named_children(): | |
fn_recursive_attn_processor(name, module, processor) | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor | |
def set_default_attn_processor(self): | |
""" | |
Disables custom attention processors and sets the default attention implementation. | |
""" | |
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnAddedKVProcessor() | |
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnProcessor() | |
else: | |
raise ValueError( | |
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" | |
) | |
self.set_attn_processor(processor) | |
def forward( | |
self, | |
hidden_states, | |
timestep: Union[torch.Tensor, float, int], | |
proj_embedding: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.BoolTensor] = None, | |
return_dict: bool = True, | |
): | |
""" | |
The [`PriorTransformer`] forward method. | |
Args: | |
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): | |
The currently predicted image embeddings. | |
timestep (`torch.LongTensor`): | |
Current denoising step. | |
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): | |
Projected embedding vector the denoising process is conditioned on. | |
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): | |
Hidden states of the text embeddings the denoising process is conditioned on. | |
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): | |
Text mask for the text embeddings. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain | |
tuple. | |
Returns: | |
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: | |
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a | |
tuple is returned where the first element is the sample tensor. | |
""" | |
batch_size = hidden_states.shape[0] | |
timesteps = timestep | |
if not torch.is_tensor(timesteps): | |
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) | |
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(hidden_states.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) | |
timesteps_projected = self.time_proj(timesteps) | |
# timesteps does not contain any weights and will always return f32 tensors | |
# but time_embedding might be fp16, so we need to cast here. | |
timesteps_projected = timesteps_projected.to(dtype=self.dtype) | |
time_embeddings = self.time_embedding(timesteps_projected) | |
if self.embedding_proj_norm is not None: | |
proj_embedding = self.embedding_proj_norm(proj_embedding) | |
proj_embeddings = self.embedding_proj(proj_embedding) | |
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: | |
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) | |
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: | |
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") | |
hidden_states = self.proj_in(hidden_states) | |
positional_embeddings = self.positional_embedding.to(hidden_states.dtype) | |
additional_embeds = [] | |
additional_embeddings_len = 0 | |
if encoder_hidden_states is not None: | |
additional_embeds.append(encoder_hidden_states) | |
additional_embeddings_len += encoder_hidden_states.shape[1] | |
if len(proj_embeddings.shape) == 2: | |
proj_embeddings = proj_embeddings[:, None, :] | |
if len(hidden_states.shape) == 2: | |
hidden_states = hidden_states[:, None, :] | |
additional_embeds = additional_embeds + [ | |
proj_embeddings, | |
time_embeddings[:, None, :], | |
hidden_states, | |
] | |
if self.prd_embedding is not None: | |
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) | |
additional_embeds.append(prd_embedding) | |
hidden_states = torch.cat( | |
additional_embeds, | |
dim=1, | |
) | |
# Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens | 
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 | |
if positional_embeddings.shape[1] < hidden_states.shape[1]: | |
positional_embeddings = F.pad( | |
positional_embeddings, | |
( | |
0, | |
0, | |
additional_embeddings_len, | |
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, | |
), | |
value=0.0, | |
) | |
hidden_states = hidden_states + positional_embeddings | |
if attention_mask is not None: | |
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 | |
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) | |
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) | |
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) | |
if self.norm_in is not None: | |
hidden_states = self.norm_in(hidden_states) | |
for block in self.transformer_blocks: | |
hidden_states = block(hidden_states, attention_mask=attention_mask) | |
hidden_states = self.norm_out(hidden_states) | |
if self.prd_embedding is not None: | |
hidden_states = hidden_states[:, -1] | |
else: | |
hidden_states = hidden_states[:, additional_embeddings_len:] | |
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) | |
if not return_dict: | |
return (predicted_image_embedding,) | |
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) | |
def post_process_latents(self, prior_latents): | |
prior_latents = (prior_latents * self.clip_std) + self.clip_mean | |
return prior_latents | |
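# Illustrative usage sketch (not part of the diffusers source); `_example_prior_transformer_usage`
# is a hypothetical helper added for illustration only. The small dimensions are arbitrary
# assumptions, not the settings of any released prior checkpoint.
def _example_prior_transformer_usage():
    model = PriorTransformer(
        num_attention_heads=2,
        attention_head_dim=8,      # inner_dim = 2 * 8 = 16
        num_layers=2,
        embedding_dim=16,
        num_embeddings=4,          # length of the text hidden states
        additional_embeddings=4,
    )
    model.set_default_attn_processor()  # reset to the default attention implementation
    batch_size = 2
    hidden_states = torch.randn(batch_size, 16)          # noisy image embedding being denoised
    proj_embedding = torch.randn(batch_size, 16)         # pooled text embedding conditioning
    encoder_hidden_states = torch.randn(batch_size, 4, 16)
    attention_mask = torch.ones(batch_size, 4, dtype=torch.bool)
    out = model(
        hidden_states,
        timestep=10,
        proj_embedding=proj_embedding,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
    ).predicted_image_embedding
    assert out.shape == (batch_size, 16)
    return out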
import math | |
from typing import Optional, Tuple | |
import torch | |
from torch import nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ..attention_processor import Attention | |
from ..embeddings import get_timestep_embedding | |
from ..modeling_utils import ModelMixin | |
class T5FilmDecoder(ModelMixin, ConfigMixin): | |
r""" | |
T5 style decoder with FiLM conditioning. | |
Args: | |
input_dims (`int`, *optional*, defaults to `128`): | |
The number of input dimensions. | |
targets_length (`int`, *optional*, defaults to `256`): | |
The length of the targets. | |
d_model (`int`, *optional*, defaults to `768`): | |
Size of the input hidden states. | |
num_layers (`int`, *optional*, defaults to `12`): | |
The number of `DecoderLayer`'s to use. | |
num_heads (`int`, *optional*, defaults to `12`): | |
The number of attention heads to use. | |
d_kv (`int`, *optional*, defaults to `64`): | |
Size of the key-value projection vectors. | |
d_ff (`int`, *optional*, defaults to `2048`): | |
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. | |
dropout_rate (`float`, *optional*, defaults to `0.1`): | |
Dropout probability. | |
""" | |
@register_to_config | |
def __init__( | |
self, | |
input_dims: int = 128, | |
targets_length: int = 256, | |
max_decoder_noise_time: float = 2000.0, | |
d_model: int = 768, | |
num_layers: int = 12, | |
num_heads: int = 12, | |
d_kv: int = 64, | |
d_ff: int = 2048, | |
dropout_rate: float = 0.1, | |
): | |
super().__init__() | |
self.conditioning_emb = nn.Sequential( | |
nn.Linear(d_model, d_model * 4, bias=False), | |
nn.SiLU(), | |
nn.Linear(d_model * 4, d_model * 4, bias=False), | |
nn.SiLU(), | |
) | |
self.position_encoding = nn.Embedding(targets_length, d_model) | |
self.position_encoding.weight.requires_grad = False | |
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) | |
self.dropout = nn.Dropout(p=dropout_rate) | |
self.decoders = nn.ModuleList() | |
for lyr_num in range(num_layers): | |
# FiLM conditional T5 decoder | |
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) | |
self.decoders.append(lyr) | |
self.decoder_norm = T5LayerNorm(d_model) | |
self.post_dropout = nn.Dropout(p=dropout_rate) | |
self.spec_out = nn.Linear(d_model, input_dims, bias=False) | |
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor: | |
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) | |
return mask.unsqueeze(-3) | |
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): | |
batch, _, _ = decoder_input_tokens.shape | |
assert decoder_noise_time.shape == (batch,) | |
# decoder_noise_time is in [0, 1), so rescale to expected timing range. | |
time_steps = get_timestep_embedding( | |
decoder_noise_time * self.config.max_decoder_noise_time, | |
embedding_dim=self.config.d_model, | |
max_period=self.config.max_decoder_noise_time, | |
).to(dtype=self.dtype) | |
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) | |
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) | |
seq_length = decoder_input_tokens.shape[1] | |
# If we want to use relative positions for audio context, we can just offset | |
# this sequence by the length of encodings_and_masks. | |
decoder_positions = torch.broadcast_to( | |
torch.arange(seq_length, device=decoder_input_tokens.device), | |
(batch, seq_length), | |
) | |
position_encodings = self.position_encoding(decoder_positions) | |
inputs = self.continuous_inputs_projection(decoder_input_tokens) | |
inputs += position_encodings | |
y = self.dropout(inputs) | |
# decoder: No padding present. | |
decoder_mask = torch.ones( | |
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype | |
) | |
# Translate encoding masks to encoder-decoder masks. | |
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] | |
# cross attend style: concat encodings | |
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) | |
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) | |
for lyr in self.decoders: | |
y = lyr( | |
y, | |
conditioning_emb=conditioning_emb, | |
encoder_hidden_states=encoded, | |
encoder_attention_mask=encoder_decoder_mask, | |
)[0] | |
y = self.decoder_norm(y) | |
y = self.post_dropout(y) | |
spec_out = self.spec_out(y) | |
return spec_out | |
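# Illustrative usage sketch (not part of the diffusers source); `_example_t5_film_decoder_usage`
# is a hypothetical helper added for illustration only, with arbitrary small sizes.
# `encodings_and_masks` is a list of `(encoder_hidden_states, encoder_attention_mask)` pairs that
# the forward pass above concatenates along the sequence dimension for cross-attention.
def _example_t5_film_decoder_usage():
    model = T5FilmDecoder(
        input_dims=8,
        targets_length=16,
        d_model=32,
        num_layers=1,
        num_heads=2,
        d_kv=16,
        d_ff=64,
    )
    batch_size = 2
    encodings_and_masks = [(torch.randn(batch_size, 5, 32), torch.ones(batch_size, 5))]
    decoder_input_tokens = torch.randn(batch_size, 16, 8)  # (batch, targets_length, input_dims)
    decoder_noise_time = torch.rand(batch_size)            # in [0, 1); rescaled internally
    spec_out = model(encodings_and_masks, decoder_input_tokens, decoder_noise_time)
    assert spec_out.shape == (batch_size, 16, 8)
    return spec_out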
class DecoderLayer(nn.Module): | |
r""" | |
T5 decoder layer. | |
Args: | |
d_model (`int`): | |
Size of the input hidden states. | |
d_kv (`int`): | |
Size of the key-value projection vectors. | |
num_heads (`int`): | |
Number of attention heads. | |
d_ff (`int`): | |
Size of the intermediate feed-forward layer. | |
dropout_rate (`float`): | |
Dropout probability. | |
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): | |
A small value used for numerical stability to avoid dividing by zero. | |
""" | |
def __init__( | |
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 | |
): | |
super().__init__() | |
self.layer = nn.ModuleList() | |
# cond self attention: layer 0 | |
self.layer.append( | |
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) | |
) | |
# cross attention: layer 1 | |
self.layer.append( | |
T5LayerCrossAttention( | |
d_model=d_model, | |
d_kv=d_kv, | |
num_heads=num_heads, | |
dropout_rate=dropout_rate, | |
layer_norm_epsilon=layer_norm_epsilon, | |
) | |
) | |
# Film Cond MLP + dropout: last layer | |
self.layer.append( | |
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) | |
) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
conditioning_emb: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.Tensor] = None, | |
encoder_attention_mask: Optional[torch.Tensor] = None, | |
encoder_decoder_position_bias=None, | |
) -> Tuple[torch.FloatTensor]: | |
hidden_states = self.layer[0]( | |
hidden_states, | |
conditioning_emb=conditioning_emb, | |
attention_mask=attention_mask, | |
) | |
if encoder_hidden_states is not None: | |
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( | |
encoder_hidden_states.dtype | |
) | |
hidden_states = self.layer[1]( | |
hidden_states, | |
key_value_states=encoder_hidden_states, | |
attention_mask=encoder_extended_attention_mask, | |
) | |
# Apply Film Conditional Feed Forward layer | |
hidden_states = self.layer[-1](hidden_states, conditioning_emb) | |
return (hidden_states,) | |
class T5LayerSelfAttentionCond(nn.Module): | |
r""" | |
T5 style self-attention layer with conditioning. | |
Args: | |
d_model (`int`): | |
Size of the input hidden states. | |
d_kv (`int`): | |
Size of the key-value projection vectors. | |
num_heads (`int`): | |
Number of attention heads. | |
dropout_rate (`float`): | |
Dropout probability. | |
""" | |
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): | |
super().__init__() | |
self.layer_norm = T5LayerNorm(d_model) | |
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) | |
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) | |
self.dropout = nn.Dropout(dropout_rate) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
conditioning_emb: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
# pre_self_attention_layer_norm | |
normed_hidden_states = self.layer_norm(hidden_states) | |
if conditioning_emb is not None: | |
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) | |
# Self-attention block | |
attention_output = self.attention(normed_hidden_states) | |
hidden_states = hidden_states + self.dropout(attention_output) | |
return hidden_states | |
class T5LayerCrossAttention(nn.Module): | |
r""" | |
T5 style cross-attention layer. | |
Args: | |
d_model (`int`): | |
Size of the input hidden states. | |
d_kv (`int`): | |
Size of the key-value projection vectors. | |
num_heads (`int`): | |
Number of attention heads. | |
dropout_rate (`float`): | |
Dropout probability. | |
layer_norm_epsilon (`float`): | |
A small value used for numerical stability to avoid dividing by zero. | |
""" | |
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): | |
super().__init__() | |
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) | |
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) | |
self.dropout = nn.Dropout(dropout_rate) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
key_value_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
normed_hidden_states = self.layer_norm(hidden_states) | |
attention_output = self.attention( | |
normed_hidden_states, | |
encoder_hidden_states=key_value_states, | |
attention_mask=attention_mask.squeeze(1), | |
) | |
layer_output = hidden_states + self.dropout(attention_output) | |
return layer_output | |
class T5LayerFFCond(nn.Module): | |
r""" | |
T5 style feed-forward conditional layer. | |
Args: | |
d_model (`int`): | |
Size of the input hidden states. | |
d_ff (`int`): | |
Size of the intermediate feed-forward layer. | |
dropout_rate (`float`): | |
Dropout probability. | |
layer_norm_epsilon (`float`): | |
A small value used for numerical stability to avoid dividing by zero. | |
""" | |
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): | |
super().__init__() | |
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) | |
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) | |
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) | |
self.dropout = nn.Dropout(dropout_rate) | |
def forward( | |
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None | |
) -> torch.FloatTensor: | |
forwarded_states = self.layer_norm(hidden_states) | |
if conditioning_emb is not None: | |
forwarded_states = self.film(forwarded_states, conditioning_emb) | |
forwarded_states = self.DenseReluDense(forwarded_states) | |
hidden_states = hidden_states + self.dropout(forwarded_states) | |
return hidden_states | |
class T5DenseGatedActDense(nn.Module): | |
r""" | |
T5 style feed-forward layer with gated activations and dropout. | |
Args: | |
d_model (`int`): | |
Size of the input hidden states. | |
d_ff (`int`): | |
Size of the intermediate feed-forward layer. | |
dropout_rate (`float`): | |
Dropout probability. | |
""" | |
def __init__(self, d_model: int, d_ff: int, dropout_rate: float): | |
super().__init__() | |
self.wi_0 = nn.Linear(d_model, d_ff, bias=False) | |
self.wi_1 = nn.Linear(d_model, d_ff, bias=False) | |
self.wo = nn.Linear(d_ff, d_model, bias=False) | |
self.dropout = nn.Dropout(dropout_rate) | |
self.act = NewGELUActivation() | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_gelu = self.act(self.wi_0(hidden_states)) | |
hidden_linear = self.wi_1(hidden_states) | |
hidden_states = hidden_gelu * hidden_linear | |
hidden_states = self.dropout(hidden_states) | |
hidden_states = self.wo(hidden_states) | |
return hidden_states | |
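# --- Usage sketch (illustrative only, hypothetical sizes): the gated feed-forward | |
# above computes wo(NewGELU(wi_0(x)) * wi_1(x)); torch's tanh-approximate GELU | |
# matches the NewGELUActivation defined below. | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
d_model_ex, d_ff_ex = 8, 32 | |
wi_0_ex = nn.Linear(d_model_ex, d_ff_ex, bias=False) | |
wi_1_ex = nn.Linear(d_model_ex, d_ff_ex, bias=False) | |
wo_ex = nn.Linear(d_ff_ex, d_model_ex, bias=False) | |
x_ex = torch.randn(2, 5, d_model_ex)  # (batch, seq_len, d_model) | |
y_ex = wo_ex(F.gelu(wi_0_ex(x_ex), approximate="tanh") * wi_1_ex(x_ex)) | |
assert y_ex.shape == x_ex.shape  # the block preserves the model dimension | |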
class T5LayerNorm(nn.Module): | |
r""" | |
T5 style layer normalization module. | |
Args: | |
hidden_size (`int`): | |
Size of the input hidden states. | |
eps (`float`, `optional`, defaults to `1e-6`): | |
A small value used for numerical stability to avoid dividing by zero. | |
""" | |
def __init__(self, hidden_size: int, eps: float = 1e-6): | |
""" | |
Construct a layernorm module in the T5 style. No bias and no subtraction of mean. | |
""" | |
super().__init__() | |
self.weight = nn.Parameter(torch.ones(hidden_size)) | |
self.variance_epsilon = eps | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean | |
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated | |
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for | |
# half-precision inputs is done in fp32 | |
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) | |
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) | |
# convert into half-precision if necessary | |
if self.weight.dtype in [torch.float16, torch.bfloat16]: | |
hidden_states = hidden_states.to(self.weight.dtype) | |
return self.weight * hidden_states | |
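# --- Usage sketch: T5LayerNorm is an RMSNorm -- it rescales by the root mean | |
# square of the features without centering; with the default all-ones weight | |
# the output equals x * rsqrt(mean(x**2) + eps). | |
import torch | |
norm_ex = T5LayerNorm(hidden_size=4, eps=1e-6) | |
x_ex = torch.randn(2, 3, 4) | |
manual_ex = x_ex * torch.rsqrt(x_ex.pow(2).mean(-1, keepdim=True) + 1e-6) | |
assert torch.allclose(norm_ex(x_ex), manual_ex) | |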
class NewGELUActivation(nn.Module): | |
""" | |
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see | |
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 | |
""" | |
def forward(self, input: torch.Tensor) -> torch.Tensor: | |
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) | |
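# --- Sanity-check sketch: this formula matches PyTorch's built-in tanh | |
# approximation of GELU, so the two can be used interchangeably. | |
import math | |
import torch | |
import torch.nn.functional as F | |
act_ex = NewGELUActivation() | |
x_ex = torch.randn(10) | |
assert torch.allclose(act_ex(x_ex), F.gelu(x_ex, approximate="tanh"), atol=1e-6) | |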
class T5FiLMLayer(nn.Module): | |
""" | |
T5 style FiLM Layer. | |
Args: | |
in_features (`int`): | |
Number of input features. | |
out_features (`int`): | |
Number of output features. | |
""" | |
def __init__(self, in_features: int, out_features: int): | |
super().__init__() | |
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) | |
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor: | |
emb = self.scale_bias(conditioning_emb) | |
scale, shift = torch.chunk(emb, 2, -1) | |
x = x * (1 + scale) + shift | |
return x | |
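# --- Usage sketch (hypothetical sizes): FiLM projects the conditioning embedding | |
# to a per-channel scale and shift and applies x * (1 + scale) + shift. | |
import torch | |
d_model_ex = 8 | |
film_ex = T5FiLMLayer(in_features=d_model_ex * 4, out_features=d_model_ex) | |
x_ex = torch.randn(2, 5, d_model_ex)         # (batch, seq_len, d_model) | |
cond_ex = torch.randn(2, 1, d_model_ex * 4)  # conditioning embedding | |
assert film_ex(x_ex, cond_ex).shape == x_ex.shape | |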
from typing import Optional | |
from torch import nn | |
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput | |
class DualTransformer2DModel(nn.Module): | |
""" | |
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. | |
Parameters: | |
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. | |
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. | |
in_channels (`int`, *optional*): | |
Pass if the input is continuous. The number of channels in the input and output. | |
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. | |
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. | |
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. | |
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. | |
Note that this is fixed at training time as it is used for learning a number of position embeddings. See | |
`ImagePositionalEmbeddings`. | |
num_vector_embeds (`int`, *optional*): | |
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. | |
Includes the class for the masked latent pixel. | |
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. | |
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. | |
The number of diffusion steps used during training. Note that this is fixed at training time as it is used | |
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for | |
up to, but not more than, `num_embeds_ada_norm` steps. | |
attention_bias (`bool`, *optional*): | |
Configure if the TransformerBlocks' attention should contain a bias parameter. | |
""" | |
def __init__( | |
self, | |
num_attention_heads: int = 16, | |
attention_head_dim: int = 88, | |
in_channels: Optional[int] = None, | |
num_layers: int = 1, | |
dropout: float = 0.0, | |
norm_num_groups: int = 32, | |
cross_attention_dim: Optional[int] = None, | |
attention_bias: bool = False, | |
sample_size: Optional[int] = None, | |
num_vector_embeds: Optional[int] = None, | |
activation_fn: str = "geglu", | |
num_embeds_ada_norm: Optional[int] = None, | |
): | |
super().__init__() | |
self.transformers = nn.ModuleList( | |
[ | |
Transformer2DModel( | |
num_attention_heads=num_attention_heads, | |
attention_head_dim=attention_head_dim, | |
in_channels=in_channels, | |
num_layers=num_layers, | |
dropout=dropout, | |
norm_num_groups=norm_num_groups, | |
cross_attention_dim=cross_attention_dim, | |
attention_bias=attention_bias, | |
sample_size=sample_size, | |
num_vector_embeds=num_vector_embeds, | |
activation_fn=activation_fn, | |
num_embeds_ada_norm=num_embeds_ada_norm, | |
) | |
for _ in range(2) | |
] | |
) | |
# Variables that can be set by a pipeline: | |
# The ratio of transformer1 to transformer2's output states to be combined during inference | |
self.mix_ratio = 0.5 | |
# The shape of `encoder_hidden_states` is expected to be | |
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` | |
self.condition_lengths = [77, 257] | |
# Which transformer to use to encode which condition. | |
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` | |
self.transformer_index_for_condition = [1, 0] | |
def forward( | |
self, | |
hidden_states, | |
encoder_hidden_states, | |
timestep=None, | |
attention_mask=None, | |
cross_attention_kwargs=None, | |
return_dict: bool = True, | |
): | |
""" | |
Args: | |
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. | |
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input | |
hidden_states. | |
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): | |
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to | |
self-attention. | |
timestep ( `torch.long`, *optional*): | |
Optional timestep to be applied as an embedding in `AdaLayerNorm` layers. Used to indicate the denoising step. | |
attention_mask (`torch.FloatTensor`, *optional*): | |
Optional attention mask to be applied in Attention. | |
cross_attention_kwargs (`dict`, *optional*): | |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
`self.processor` in | |
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. | |
Returns: | |
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: | |
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When | |
returning a tuple, the first element is the sample tensor. | |
""" | |
input_states = hidden_states | |
encoded_states = [] | |
tokens_start = 0 | |
# attention_mask is not used yet | |
for i in range(2): | |
# for each of the two transformers, pass the corresponding condition tokens | |
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] | |
transformer_index = self.transformer_index_for_condition[i] | |
encoded_state = self.transformers[transformer_index]( | |
input_states, | |
encoder_hidden_states=condition_state, | |
timestep=timestep, | |
cross_attention_kwargs=cross_attention_kwargs, | |
return_dict=False, | |
)[0] | |
encoded_states.append(encoded_state - input_states) | |
tokens_start += self.condition_lengths[i] | |
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) | |
output_states = output_states + input_states | |
if not return_dict: | |
return (output_states,) | |
return Transformer2DModelOutput(sample=output_states) | |
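# --- Sketch of the mixing rule above with plain tensors (values illustrative | |
# only): each transformer contributes its residual, the residuals are blended | |
# with `mix_ratio`, and the input is added back. | |
import torch | |
input_states_ex = torch.randn(1, 4, 8, 8) | |
delta_0_ex = torch.randn_like(input_states_ex)  # transformers[1](cond[0]) - input | |
delta_1_ex = torch.randn_like(input_states_ex)  # transformers[0](cond[1]) - input | |
mix_ratio_ex = 0.5 | |
output_ex = delta_0_ex * mix_ratio_ex + delta_1_ex * (1 - mix_ratio_ex) + input_states_ex | |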
# coding=utf-8 | |
# Copyright 2024 HuggingFace Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ..utils import USE_PEFT_BACKEND | |
from .lora import LoRACompatibleLinear | |
ACTIVATION_FUNCTIONS = { | |
"swish": nn.SiLU(), | |
"silu": nn.SiLU(), | |
"mish": nn.Mish(), | |
"gelu": nn.GELU(), | |
"relu": nn.ReLU(), | |
} | |
def get_activation(act_fn: str) -> nn.Module: | |
"""Helper function to get activation function from string. | |
Args: | |
act_fn (str): Name of activation function. | |
Returns: | |
nn.Module: Activation function. | |
""" | |
act_fn = act_fn.lower() | |
if act_fn in ACTIVATION_FUNCTIONS: | |
return ACTIVATION_FUNCTIONS[act_fn] | |
else: | |
raise ValueError(f"Unsupported activation function: {act_fn}") | |
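# --- Usage sketch: resolve an activation module by name. | |
import torch | |
silu_ex = get_activation("silu")  # returns the shared nn.SiLU() instance | |
y_ex = silu_ex(torch.randn(4))    # same result as torch.nn.functional.silu | |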
class GELU(nn.Module): | |
r""" | |
GELU activation function with tanh approximation support with `approximate="tanh"`. | |
Parameters: | |
dim_in (`int`): The number of channels in the input. | |
dim_out (`int`): The number of channels in the output. | |
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. | |
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. | |
""" | |
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): | |
super().__init__() | |
self.proj = nn.Linear(dim_in, dim_out, bias=bias) | |
self.approximate = approximate | |
def gelu(self, gate: torch.Tensor) -> torch.Tensor: | |
if gate.device.type != "mps": | |
return F.gelu(gate, approximate=self.approximate) | |
# mps: gelu is not implemented for float16 | |
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) | |
def forward(self, hidden_states): | |
hidden_states = self.proj(hidden_states) | |
hidden_states = self.gelu(hidden_states) | |
return hidden_states | |
class GEGLU(nn.Module): | |
r""" | |
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. | |
Parameters: | |
dim_in (`int`): The number of channels in the input. | |
dim_out (`int`): The number of channels in the output. | |
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. | |
""" | |
def __init__(self, dim_in: int, dim_out: int, bias: bool = True): | |
super().__init__() | |
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear | |
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) | |
def gelu(self, gate: torch.Tensor) -> torch.Tensor: | |
if gate.device.type != "mps": | |
return F.gelu(gate) | |
# mps: gelu is not implemented for float16 | |
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) | |
def forward(self, hidden_states, scale: float = 1.0): | |
args = () if USE_PEFT_BACKEND else (scale,) | |
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) | |
return hidden_states * self.gelu(gate) | |
class ApproximateGELU(nn.Module): | |
r""" | |
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this | |
[paper](https://arxiv.org/abs/1606.08415). | |
Parameters: | |
dim_in (`int`): The number of channels in the input. | |
dim_out (`int`): The number of channels in the output. | |
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. | |
""" | |
def __init__(self, dim_in: int, dim_out: int, bias: bool = True): | |
super().__init__() | |
self.proj = nn.Linear(dim_in, dim_out, bias=bias) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
x = self.proj(x) | |
return x * torch.sigmoid(1.702 * x) | |
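# --- Sketch: the sigmoid form x * sigmoid(1.702 * x) is a cheap approximation | |
# of exact GELU; the deviation stays small over typical activation ranges. | |
import torch | |
import torch.nn.functional as F | |
x_ex = torch.linspace(-3, 3, steps=13) | |
approx_ex = x_ex * torch.sigmoid(1.702 * x_ex) | |
max_err_ex = (approx_ex - F.gelu(x_ex)).abs().max()  # on the order of 1e-2 | |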
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers | |
import math | |
from functools import partial | |
from typing import Tuple | |
import flax | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
from flax.core.frozen_dict import FrozenDict | |
from ..configuration_utils import ConfigMixin, flax_register_to_config | |
from ..utils import BaseOutput | |
from .modeling_flax_utils import FlaxModelMixin | |
@flax.struct.dataclass | |
class FlaxDecoderOutput(BaseOutput): | |
""" | |
Output of decoding method. | |
Args: | |
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): | |
The decoded output sample from the last layer of the model. | |
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): | |
The `dtype` of the parameters. | |
""" | |
sample: jnp.ndarray | |
@flax.struct.dataclass | |
class FlaxAutoencoderKLOutput(BaseOutput): | |
""" | |
Output of AutoencoderKL encoding method. | |
Args: | |
latent_dist (`FlaxDiagonalGaussianDistribution`): | |
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. | |
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. | |
""" | |
latent_dist: "FlaxDiagonalGaussianDistribution" | |
class FlaxUpsample2D(nn.Module): | |
""" | |
Flax implementation of 2D Upsample layer | |
Args: | |
in_channels (`int`): | |
Input channels | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.conv = nn.Conv( | |
self.in_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
def __call__(self, hidden_states): | |
batch, height, width, channels = hidden_states.shape | |
hidden_states = jax.image.resize( | |
hidden_states, | |
shape=(batch, height * 2, width * 2, channels), | |
method="nearest", | |
) | |
hidden_states = self.conv(hidden_states) | |
return hidden_states | |
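# --- Usage sketch (the Flax modules in this file expect NHWC / channels-last | |
# inputs): nearest-neighbour resize doubles H and W, then a 3x3 conv refines. | |
import jax | |
import jax.numpy as jnp | |
upsample_ex = FlaxUpsample2D(in_channels=64) | |
x_ex = jnp.zeros((1, 32, 32, 64))  # (batch, height, width, channels) | |
params_ex = upsample_ex.init(jax.random.PRNGKey(0), x_ex) | |
assert upsample_ex.apply(params_ex, x_ex).shape == (1, 64, 64, 64) | |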
class FlaxDownsample2D(nn.Module): | |
""" | |
Flax implementation of 2D Downsample layer | |
Args: | |
in_channels (`int`): | |
Input channels | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.conv = nn.Conv( | |
self.in_channels, | |
kernel_size=(3, 3), | |
strides=(2, 2), | |
padding="VALID", | |
dtype=self.dtype, | |
) | |
def __call__(self, hidden_states): | |
pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim | |
hidden_states = jnp.pad(hidden_states, pad_width=pad) | |
hidden_states = self.conv(hidden_states) | |
return hidden_states | |
class FlaxResnetBlock2D(nn.Module): | |
""" | |
Flax implementation of 2D Resnet Block. | |
Args: | |
in_channels (`int`): | |
Input channels | |
out_channels (`int`): | |
Output channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
groups (:obj:`int`, *optional*, defaults to `32`): | |
The number of groups to use for group norm. | |
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): | |
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
out_channels: int = None | |
dropout: float = 0.0 | |
groups: int = 32 | |
use_nin_shortcut: bool = None | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
out_channels = self.in_channels if self.out_channels is None else self.out_channels | |
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) | |
self.conv1 = nn.Conv( | |
out_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) | |
self.dropout_layer = nn.Dropout(self.dropout) | |
self.conv2 = nn.Conv( | |
out_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut | |
self.conv_shortcut = None | |
if use_nin_shortcut: | |
self.conv_shortcut = nn.Conv( | |
out_channels, | |
kernel_size=(1, 1), | |
strides=(1, 1), | |
padding="VALID", | |
dtype=self.dtype, | |
) | |
def __call__(self, hidden_states, deterministic=True): | |
residual = hidden_states | |
hidden_states = self.norm1(hidden_states) | |
hidden_states = nn.swish(hidden_states) | |
hidden_states = self.conv1(hidden_states) | |
hidden_states = self.norm2(hidden_states) | |
hidden_states = nn.swish(hidden_states) | |
hidden_states = self.dropout_layer(hidden_states, deterministic) | |
hidden_states = self.conv2(hidden_states) | |
if self.conv_shortcut is not None: | |
residual = self.conv_shortcut(residual) | |
return hidden_states + residual | |
class FlaxAttentionBlock(nn.Module): | |
r""" | |
Flax Convolutional based multi-head attention block for diffusion-based VAE. | |
Parameters: | |
channels (:obj:`int`): | |
Input channels | |
num_head_channels (:obj:`int`, *optional*, defaults to `None`): | |
Number of attention heads | |
num_groups (:obj:`int`, *optional*, defaults to `32`): | |
The number of groups to use for group norm | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
channels: int | |
num_head_channels: int = None | |
num_groups: int = 32 | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 | |
dense = partial(nn.Dense, self.channels, dtype=self.dtype) | |
self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6) | |
self.query, self.key, self.value = dense(), dense(), dense() | |
self.proj_attn = dense() | |
def transpose_for_scores(self, projection): | |
new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) | |
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) | |
new_projection = projection.reshape(new_projection_shape) | |
# (B, T, H, D) -> (B, H, T, D) | |
new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) | |
return new_projection | |
def __call__(self, hidden_states): | |
residual = hidden_states | |
batch, height, width, channels = hidden_states.shape | |
hidden_states = self.group_norm(hidden_states) | |
hidden_states = hidden_states.reshape((batch, height * width, channels)) | |
query = self.query(hidden_states) | |
key = self.key(hidden_states) | |
value = self.value(hidden_states) | |
# transpose | |
query = self.transpose_for_scores(query) | |
key = self.transpose_for_scores(key) | |
value = self.transpose_for_scores(value) | |
# compute attentions | |
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) | |
attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) | |
attn_weights = nn.softmax(attn_weights, axis=-1) | |
# attend to values | |
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) | |
hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) | |
new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) | |
hidden_states = hidden_states.reshape(new_hidden_states_shape) | |
hidden_states = self.proj_attn(hidden_states) | |
hidden_states = hidden_states.reshape((batch, height, width, channels)) | |
hidden_states = hidden_states + residual | |
return hidden_states | |
class FlaxDownEncoderBlock2D(nn.Module): | |
r""" | |
Flax Resnet blocks-based Encoder block for diffusion-based VAE. | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
out_channels (:obj:`int`): | |
Output channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of Resnet layer block | |
resnet_groups (:obj:`int`, *optional*, defaults to `32`): | |
The number of groups to use for the Resnet block group norm | |
add_downsample (:obj:`bool`, *optional*, defaults to `True`): | |
Whether to add downsample layer | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
out_channels: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
resnet_groups: int = 32 | |
add_downsample: bool = True | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
resnets = [] | |
for i in range(self.num_layers): | |
in_channels = self.in_channels if i == 0 else self.out_channels | |
res_block = FlaxResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=self.out_channels, | |
dropout=self.dropout, | |
groups=self.resnet_groups, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
self.resnets = resnets | |
if self.add_downsample: | |
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) | |
def __call__(self, hidden_states, deterministic=True): | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states, deterministic=deterministic) | |
if self.add_downsample: | |
hidden_states = self.downsamplers_0(hidden_states) | |
return hidden_states | |
class FlaxUpDecoderBlock2D(nn.Module): | |
r""" | |
Flax Resnet blocks-based Decoder block for diffusion-based VAE. | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
out_channels (:obj:`int`): | |
Output channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of Resnet layer block | |
resnet_groups (:obj:`int`, *optional*, defaults to `32`): | |
The number of groups to use for the Resnet block group norm | |
add_upsample (:obj:`bool`, *optional*, defaults to `True`): | |
Whether to add upsample layer | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
out_channels: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
resnet_groups: int = 32 | |
add_upsample: bool = True | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
resnets = [] | |
for i in range(self.num_layers): | |
in_channels = self.in_channels if i == 0 else self.out_channels | |
res_block = FlaxResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=self.out_channels, | |
dropout=self.dropout, | |
groups=self.resnet_groups, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
self.resnets = resnets | |
if self.add_upsample: | |
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) | |
def __call__(self, hidden_states, deterministic=True): | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states, deterministic=deterministic) | |
if self.add_upsample: | |
hidden_states = self.upsamplers_0(hidden_states) | |
return hidden_states | |
class FlaxUNetMidBlock2D(nn.Module): | |
r""" | |
Flax Unet Mid-Block module. | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of Resnet layer block | |
resnet_groups (:obj:`int`, *optional*, defaults to `32`): | |
The number of groups to use for the Resnet and Attention block group norm | |
num_attention_heads (:obj:`int`, *optional*, defaults to `1`): | |
Number of attention heads for each attention block | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
resnet_groups: int = 32 | |
num_attention_heads: int = 1 | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) | |
# there is always at least one resnet | |
resnets = [ | |
FlaxResnetBlock2D( | |
in_channels=self.in_channels, | |
out_channels=self.in_channels, | |
dropout=self.dropout, | |
groups=resnet_groups, | |
dtype=self.dtype, | |
) | |
] | |
attentions = [] | |
for _ in range(self.num_layers): | |
attn_block = FlaxAttentionBlock( | |
channels=self.in_channels, | |
num_head_channels=self.num_attention_heads, | |
num_groups=resnet_groups, | |
dtype=self.dtype, | |
) | |
attentions.append(attn_block) | |
res_block = FlaxResnetBlock2D( | |
in_channels=self.in_channels, | |
out_channels=self.in_channels, | |
dropout=self.dropout, | |
groups=resnet_groups, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
self.resnets = resnets | |
self.attentions = attentions | |
def __call__(self, hidden_states, deterministic=True): | |
hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) | |
for attn, resnet in zip(self.attentions, self.resnets[1:]): | |
hidden_states = attn(hidden_states) | |
hidden_states = resnet(hidden_states, deterministic=deterministic) | |
return hidden_states | |
class FlaxEncoder(nn.Module): | |
r""" | |
Flax Implementation of VAE Encoder. | |
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) | |
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to | |
general usage and behavior. | |
Finally, this model supports inherent JAX features such as: | |
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) | |
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) | |
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) | |
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) | |
Parameters: | |
in_channels (:obj:`int`, *optional*, defaults to 3): | |
Input channels | |
out_channels (:obj:`int`, *optional*, defaults to 3): | |
Output channels | |
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): | |
DownEncoder block type | |
block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`): | |
Tuple containing the number of output channels for each block | |
layers_per_block (:obj:`int`, *optional*, defaults to `2`): | |
Number of ResNet layers for each block | |
norm_num_groups (:obj:`int`, *optional*, defaults to `32`): | |
The number of groups to use for group norm | |
act_fn (:obj:`str`, *optional*, defaults to `silu`): | |
Activation function | |
double_z (:obj:`bool`, *optional*, defaults to `False`): | |
Whether to double the last output channels | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int = 3 | |
out_channels: int = 3 | |
down_block_types: Tuple[str] = ("DownEncoderBlock2D",) | |
block_out_channels: Tuple[int] = (64,) | |
layers_per_block: int = 2 | |
norm_num_groups: int = 32 | |
act_fn: str = "silu" | |
double_z: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
block_out_channels = self.block_out_channels | |
# in | |
self.conv_in = nn.Conv( | |
block_out_channels[0], | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
# downsampling | |
down_blocks = [] | |
output_channel = block_out_channels[0] | |
for i, _ in enumerate(self.down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = FlaxDownEncoderBlock2D( | |
in_channels=input_channel, | |
out_channels=output_channel, | |
num_layers=self.layers_per_block, | |
resnet_groups=self.norm_num_groups, | |
add_downsample=not is_final_block, | |
dtype=self.dtype, | |
) | |
down_blocks.append(down_block) | |
self.down_blocks = down_blocks | |
# middle | |
self.mid_block = FlaxUNetMidBlock2D( | |
in_channels=block_out_channels[-1], | |
resnet_groups=self.norm_num_groups, | |
num_attention_heads=None, | |
dtype=self.dtype, | |
) | |
# end | |
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels | |
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) | |
self.conv_out = nn.Conv( | |
conv_out_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
def __call__(self, sample, deterministic: bool = True): | |
# in | |
sample = self.conv_in(sample) | |
# downsampling | |
for block in self.down_blocks: | |
sample = block(sample, deterministic=deterministic) | |
# middle | |
sample = self.mid_block(sample, deterministic=deterministic) | |
# end | |
sample = self.conv_norm_out(sample) | |
sample = nn.swish(sample) | |
sample = self.conv_out(sample) | |
return sample | |
class FlaxDecoder(nn.Module): | |
r""" | |
Flax Implementation of VAE Decoder. | |
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) | |
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to | |
general usage and behavior. | |
Finally, this model supports inherent JAX features such as: | |
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) | |
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) | |
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) | |
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) | |
Parameters: | |
in_channels (:obj:`int`, *optional*, defaults to 3): | |
Input channels | |
out_channels (:obj:`int`, *optional*, defaults to 3): | |
Output channels | |
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): | |
UpDecoder block type | |
block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`): | |
Tuple containing the number of output channels for each block | |
layers_per_block (:obj:`int`, *optional*, defaults to `2`): | |
Number of ResNet layers for each block | |
norm_num_groups (:obj:`int`, *optional*, defaults to `32`): | |
The number of groups to use for group norm | |
act_fn (:obj:`str`, *optional*, defaults to `silu`): | |
Activation function | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int = 3 | |
out_channels: int = 3 | |
up_block_types: Tuple[str] = ("UpDecoderBlock2D",) | |
block_out_channels: Tuple[int] = (64,) | |
layers_per_block: int = 2 | |
norm_num_groups: int = 32 | |
act_fn: str = "silu" | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
block_out_channels = self.block_out_channels | |
# z to block_in | |
self.conv_in = nn.Conv( | |
block_out_channels[-1], | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
# middle | |
self.mid_block = FlaxUNetMidBlock2D( | |
in_channels=block_out_channels[-1], | |
resnet_groups=self.norm_num_groups, | |
num_attention_heads=None, | |
dtype=self.dtype, | |
) | |
# upsampling | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
output_channel = reversed_block_out_channels[0] | |
up_blocks = [] | |
for i, _ in enumerate(self.up_block_types): | |
prev_output_channel = output_channel | |
output_channel = reversed_block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
up_block = FlaxUpDecoderBlock2D( | |
in_channels=prev_output_channel, | |
out_channels=output_channel, | |
num_layers=self.layers_per_block + 1, | |
resnet_groups=self.norm_num_groups, | |
add_upsample=not is_final_block, | |
dtype=self.dtype, | |
) | |
up_blocks.append(up_block) | |
prev_output_channel = output_channel | |
self.up_blocks = up_blocks | |
# end | |
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) | |
self.conv_out = nn.Conv( | |
self.out_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
def __call__(self, sample, deterministic: bool = True): | |
# z to block_in | |
sample = self.conv_in(sample) | |
# middle | |
sample = self.mid_block(sample, deterministic=deterministic) | |
# upsampling | |
for block in self.up_blocks: | |
sample = block(sample, deterministic=deterministic) | |
sample = self.conv_norm_out(sample) | |
sample = nn.swish(sample) | |
sample = self.conv_out(sample) | |
return sample | |
class FlaxDiagonalGaussianDistribution(object): | |
def __init__(self, parameters, deterministic=False): | |
# Last axis to account for channels-last | |
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) | |
self.logvar = jnp.clip(self.logvar, -30.0, 20.0) | |
self.deterministic = deterministic | |
self.std = jnp.exp(0.5 * self.logvar) | |
self.var = jnp.exp(self.logvar) | |
if self.deterministic: | |
self.var = self.std = jnp.zeros_like(self.mean) | |
def sample(self, key): | |
return self.mean + self.std * jax.random.normal(key, self.mean.shape) | |
def kl(self, other=None): | |
if self.deterministic: | |
return jnp.array([0.0]) | |
if other is None: | |
return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) | |
return 0.5 * jnp.sum( | |
jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, | |
axis=[1, 2, 3], | |
) | |
def nll(self, sample, axis=[1, 2, 3]): | |
if self.deterministic: | |
return jnp.array([0.0]) | |
logtwopi = jnp.log(2.0 * jnp.pi) | |
return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) | |
def mode(self): | |
return self.mean | |
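# --- Usage sketch: the moments tensor packs mean and logvar on the last axis; | |
# sampling uses the reparameterisation mean + std * eps. | |
import jax | |
import jax.numpy as jnp | |
moments_ex = jnp.zeros((1, 4, 4, 8))  # 2 * latent_channels on the last axis | |
dist_ex = FlaxDiagonalGaussianDistribution(moments_ex) | |
z_ex = dist_ex.sample(jax.random.PRNGKey(0)) | |
assert z_ex.shape == (1, 4, 4, 4) | |
assert float(dist_ex.kl()[0]) == 0.0  # zero mean / unit variance matches the prior | |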
@flax_register_to_config | |
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): | |
r""" | |
Flax implementation of a VAE model with KL loss for decoding latent representations. | |
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods | |
implemented for all models (such as downloading or saving). | |
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) | |
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its | |
general usage and behavior. | |
Inherent JAX features such as the following are supported: | |
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) | |
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) | |
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) | |
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) | |
Parameters: | |
in_channels (`int`, *optional*, defaults to 3): | |
Number of channels in the input image. | |
out_channels (`int`, *optional*, defaults to 3): | |
Number of channels in the output. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): | |
Tuple of downsample block types. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): | |
Tuple of upsample block types. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): | |
Tuple of block output channels. | |
layers_per_block (`int`, *optional*, defaults to `1`): | |
Number of ResNet layers for each block. | |
act_fn (`str`, *optional*, defaults to `silu`): | |
The activation function to use. | |
latent_channels (`int`, *optional*, defaults to `4`): | |
Number of channels in the latent space. | |
norm_num_groups (`int`, *optional*, defaults to `32`): | |
The number of groups for normalization. | |
sample_size (`int`, *optional*, defaults to 32): | |
Sample input size. | |
scaling_factor (`float`, *optional*, defaults to 0.18215): | |
The component-wise standard deviation of the trained latent space computed using the first batch of the | |
training set. This is used to scale the latent space to have unit variance when training the diffusion | |
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the | |
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 | |
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image | |
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. | |
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): | |
The `dtype` of the parameters. | |
""" | |
in_channels: int = 3 | |
out_channels: int = 3 | |
down_block_types: Tuple[str] = ("DownEncoderBlock2D",) | |
up_block_types: Tuple[str] = ("UpDecoderBlock2D",) | |
block_out_channels: Tuple[int] = (64,) | |
layers_per_block: int = 1 | |
act_fn: str = "silu" | |
latent_channels: int = 4 | |
norm_num_groups: int = 32 | |
sample_size: int = 32 | |
scaling_factor: float = 0.18215 | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.encoder = FlaxEncoder( | |
in_channels=self.config.in_channels, | |
out_channels=self.config.latent_channels, | |
down_block_types=self.config.down_block_types, | |
block_out_channels=self.config.block_out_channels, | |
layers_per_block=self.config.layers_per_block, | |
act_fn=self.config.act_fn, | |
norm_num_groups=self.config.norm_num_groups, | |
double_z=True, | |
dtype=self.dtype, | |
) | |
self.decoder = FlaxDecoder( | |
in_channels=self.config.latent_channels, | |
out_channels=self.config.out_channels, | |
up_block_types=self.config.up_block_types, | |
block_out_channels=self.config.block_out_channels, | |
layers_per_block=self.config.layers_per_block, | |
norm_num_groups=self.config.norm_num_groups, | |
act_fn=self.config.act_fn, | |
dtype=self.dtype, | |
) | |
self.quant_conv = nn.Conv( | |
2 * self.config.latent_channels, | |
kernel_size=(1, 1), | |
strides=(1, 1), | |
padding="VALID", | |
dtype=self.dtype, | |
) | |
self.post_quant_conv = nn.Conv( | |
self.config.latent_channels, | |
kernel_size=(1, 1), | |
strides=(1, 1), | |
padding="VALID", | |
dtype=self.dtype, | |
) | |
def init_weights(self, rng: jax.Array) -> FrozenDict: | |
# init input tensors | |
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) | |
sample = jnp.zeros(sample_shape, dtype=jnp.float32) | |
params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) | |
rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} | |
return self.init(rngs, sample)["params"] | |
def encode(self, sample, deterministic: bool = True, return_dict: bool = True): | |
sample = jnp.transpose(sample, (0, 2, 3, 1)) | |
hidden_states = self.encoder(sample, deterministic=deterministic) | |
moments = self.quant_conv(hidden_states) | |
posterior = FlaxDiagonalGaussianDistribution(moments) | |
if not return_dict: | |
return (posterior,) | |
return FlaxAutoencoderKLOutput(latent_dist=posterior) | |
def decode(self, latents, deterministic: bool = True, return_dict: bool = True): | |
if latents.shape[-1] != self.config.latent_channels: | |
latents = jnp.transpose(latents, (0, 2, 3, 1)) | |
hidden_states = self.post_quant_conv(latents) | |
hidden_states = self.decoder(hidden_states, deterministic=deterministic) | |
hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) | |
if not return_dict: | |
return (hidden_states,) | |
return FlaxDecoderOutput(sample=hidden_states) | |
def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): | |
posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) | |
if sample_posterior: | |
rng = self.make_rng("gaussian") | |
hidden_states = posterior.latent_dist.sample(rng) | |
else: | |
hidden_states = posterior.latent_dist.mode() | |
sample = self.decode(hidden_states, return_dict=return_dict).sample | |
if not return_dict: | |
return (sample,) | |
return FlaxDecoderOutput(sample=sample) | |
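# --- Usage sketch (random weights, default tiny config; shapes illustrative): | |
# encode an image, scale the latent with `scaling_factor` as described above, | |
# then undo the scaling before decoding. | |
import jax | |
import jax.numpy as jnp | |
vae_ex = FlaxAutoencoderKL() | |
params_ex = vae_ex.init_weights(jax.random.PRNGKey(0)) | |
image_ex = jnp.zeros((1, 3, 32, 32))  # NCHW input, transposed to NHWC internally | |
latent_dist_ex = vae_ex.apply({"params": params_ex}, image_ex, method=vae_ex.encode).latent_dist | |
z_ex = latent_dist_ex.mode() * vae_ex.config.scaling_factor  # scale for the diffusion model | |
decoded_ex = vae_ex.apply({"params": params_ex}, z_ex / vae_ex.config.scaling_factor, method=vae_ex.decode).sample | |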
from ..utils import deprecate | |
from .unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput | |
class UNet2DConditionOutput(UNet2DConditionOutput): | |
deprecation_message = "Importing `UNet2DConditionOutput` from `diffusers.models.unet_2d_condition` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput`, instead." | |
deprecate("UNet2DConditionOutput", "0.29", deprecation_message) | |
class UNet2DConditionModel(UNet2DConditionModel): | |
deprecation_message = "Importing `UNet2DConditionModel` from `diffusers.models.unet_2d_condition` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel`, instead." | |
deprecate("UNet2DConditionModel", "0.29", deprecation_message) | |
# IMPORTANT: # | |
################################################################### | |
# ----------------------------------------------------------------# | |
# This file is deprecated and will be removed soon # | |
# (as soon as PEFT will become a required dependency for LoRA) # | |
# ----------------------------------------------------------------# | |
################################################################### | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ..utils import deprecate, logging | |
from ..utils.import_utils import is_transformers_available | |
if is_transformers_available(): | |
from transformers import CLIPTextModel, CLIPTextModelWithProjection | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
def text_encoder_attn_modules(text_encoder): | |
attn_modules = [] | |
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): | |
for i, layer in enumerate(text_encoder.text_model.encoder.layers): | |
name = f"text_model.encoder.layers.{i}.self_attn" | |
mod = layer.self_attn | |
attn_modules.append((name, mod)) | |
else: | |
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") | |
return attn_modules | |
def text_encoder_mlp_modules(text_encoder): | |
mlp_modules = [] | |
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): | |
for i, layer in enumerate(text_encoder.text_model.encoder.layers): | |
mlp_mod = layer.mlp | |
name = f"text_model.encoder.layers.{i}.mlp" | |
mlp_modules.append((name, mlp_mod)) | |
else: | |
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}") | |
return mlp_modules | |
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): | |
for _, attn_module in text_encoder_attn_modules(text_encoder): | |
if isinstance(attn_module.q_proj, PatchedLoraProjection): | |
attn_module.q_proj.lora_scale = lora_scale | |
attn_module.k_proj.lora_scale = lora_scale | |
attn_module.v_proj.lora_scale = lora_scale | |
attn_module.out_proj.lora_scale = lora_scale | |
for _, mlp_module in text_encoder_mlp_modules(text_encoder): | |
if isinstance(mlp_module.fc1, PatchedLoraProjection): | |
mlp_module.fc1.lora_scale = lora_scale | |
mlp_module.fc2.lora_scale = lora_scale | |
class PatchedLoraProjection(torch.nn.Module): | |
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): | |
deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." | |
deprecate("PatchedLoraProjection", "1.0.0", deprecation_message) | |
super().__init__() | |
from ..models.lora import LoRALinearLayer | |
self.regular_linear_layer = regular_linear_layer | |
device = self.regular_linear_layer.weight.device | |
if dtype is None: | |
dtype = self.regular_linear_layer.weight.dtype | |
self.lora_linear_layer = LoRALinearLayer( | |
self.regular_linear_layer.in_features, | |
self.regular_linear_layer.out_features, | |
network_alpha=network_alpha, | |
device=device, | |
dtype=dtype, | |
rank=rank, | |
) | |
self.lora_scale = lora_scale | |
# overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved | |
# when saving the whole text encoder model and when LoRA is unloaded or fused | |
def state_dict(self, *args, destination=None, prefix="", keep_vars=False): | |
if self.lora_linear_layer is None: | |
return self.regular_linear_layer.state_dict( | |
*args, destination=destination, prefix=prefix, keep_vars=keep_vars | |
) | |
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) | |
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False): | |
if self.lora_linear_layer is None: | |
return | |
dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device | |
w_orig = self.regular_linear_layer.weight.data.float() | |
w_up = self.lora_linear_layer.up.weight.data.float() | |
w_down = self.lora_linear_layer.down.weight.data.float() | |
if self.lora_linear_layer.network_alpha is not None: | |
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank | |
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) | |
if safe_fusing and torch.isnan(fused_weight).any().item(): | |
raise ValueError( | |
"This LoRA weight seems to be broken. " | |
f"Encountered NaN values when trying to fuse LoRA weights for {self}." | |
"LoRA weights will not be fused." | |
) | |
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype) | |
# we can drop the lora layer now | |
self.lora_linear_layer = None | |
# offload the up and down matrices to CPU to not blow the memory | |
self.w_up = w_up.cpu() | |
self.w_down = w_down.cpu() | |
self.lora_scale = lora_scale | |
def _unfuse_lora(self): | |
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): | |
return | |
fused_weight = self.regular_linear_layer.weight.data | |
dtype, device = fused_weight.dtype, fused_weight.device | |
w_up = self.w_up.to(device=device).float() | |
w_down = self.w_down.to(device).float() | |
unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) | |
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype) | |
self.w_up = None | |
self.w_down = None | |
def forward(self, input): | |
if self.lora_scale is None: | |
self.lora_scale = 1.0 | |
if self.lora_linear_layer is None: | |
return self.regular_linear_layer(input) | |
return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input)) | |
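# --- Sketch of the fusion arithmetic in `_fuse_lora` with plain tensors: the | |
# fused weight is W + lora_scale * (W_up @ W_down). | |
import torch | |
w_orig_ex = torch.randn(16, 16) | |
w_up_ex = torch.randn(16, 4) | |
w_down_ex = torch.randn(4, 16) | |
fused_ex = w_orig_ex + 1.0 * torch.bmm(w_up_ex[None, :], w_down_ex[None, :])[0] | |
assert torch.allclose(fused_ex, w_orig_ex + w_up_ex @ w_down_ex) | |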
class LoRALinearLayer(nn.Module): | |
r""" | |
A linear layer that is used with LoRA. | |
Parameters: | |
in_features (`int`): | |
Number of input features. | |
out_features (`int`): | |
Number of output features. | |
rank (`int`, `optional`, defaults to 4): | |
The rank of the LoRA layer. | |
network_alpha (`float`, `optional`, defaults to `None`): | |
The value of the network alpha used for stable learning and preventing underflow. This value has the same | |
meaning as the `--network_alpha` option in the kohya-ss trainer script. See | |
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning | |
device (`torch.device`, `optional`, defaults to `None`): | |
The device to use for the layer's weights. | |
dtype (`torch.dtype`, `optional`, defaults to `None`): | |
The dtype to use for the layer's weights. | |
""" | |
def __init__( | |
self, | |
in_features: int, | |
out_features: int, | |
rank: int = 4, | |
network_alpha: Optional[float] = None, | |
device: Optional[Union[torch.device, str]] = None, | |
dtype: Optional[torch.dtype] = None, | |
): | |
super().__init__() | |
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) | |
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) | |
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. | |
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning | |
self.network_alpha = network_alpha | |
self.rank = rank | |
self.out_features = out_features | |
self.in_features = in_features | |
nn.init.normal_(self.down.weight, std=1 / rank) | |
nn.init.zeros_(self.up.weight) | |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
orig_dtype = hidden_states.dtype | |
dtype = self.down.weight.dtype | |
down_hidden_states = self.down(hidden_states.to(dtype)) | |
up_hidden_states = self.up(down_hidden_states) | |
if self.network_alpha is not None: | |
up_hidden_states *= self.network_alpha / self.rank | |
return up_hidden_states.to(orig_dtype) | |
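# --- Usage sketch (hypothetical sizes): the LoRA update up(down(x)) starts at | |
# zero because `up` is zero-initialised, so adding it residually leaves the | |
# base layer's output unchanged until training updates it. | |
import torch | |
from torch import nn | |
base_ex = nn.Linear(16, 16) | |
lora_ex = LoRALinearLayer(16, 16, rank=4, network_alpha=4) | |
x_ex = torch.randn(2, 16) | |
assert torch.allclose(base_ex(x_ex) + 1.0 * lora_ex(x_ex), base_ex(x_ex)) | |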
class LoRAConv2dLayer(nn.Module): | |
r""" | |
A convolutional layer that is used with LoRA. | |
Parameters: | |
in_features (`int`): | |
Number of input features. | |
out_features (`int`): | |
Number of output features. | |
rank (`int`, `optional`, defaults to 4): | |
The rank of the LoRA layer. | |
kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1): | |
The kernel size of the convolution. | |
stride (`int` or `tuple` of two `int`, `optional`, defaults to 1): | |
The stride of the convolution. | |
padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0): | |
The padding of the convolution. | |
network_alpha (`float`, `optional`, defaults to `None`): | |
The value of the network alpha used for stable learning and preventing underflow. This value has the same | |
meaning as the `--network_alpha` option in the kohya-ss trainer script. See | |
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning | |
""" | |
def __init__( | |
self, | |
in_features: int, | |
out_features: int, | |
rank: int = 4, | |
kernel_size: Union[int, Tuple[int, int]] = (1, 1), | |
stride: Union[int, Tuple[int, int]] = (1, 1), | |
padding: Union[int, Tuple[int, int], str] = 0, | |
network_alpha: Optional[float] = None, | |
): | |
super().__init__() | |
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) | |
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer | |
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 | |
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) | |
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. | |
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning | |
self.network_alpha = network_alpha | |
self.rank = rank | |
nn.init.normal_(self.down.weight, std=1 / rank) | |
nn.init.zeros_(self.up.weight) | |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
orig_dtype = hidden_states.dtype | |
dtype = self.down.weight.dtype | |
down_hidden_states = self.down(hidden_states.to(dtype)) | |
up_hidden_states = self.up(down_hidden_states) | |
if self.network_alpha is not None: | |
up_hidden_states *= self.network_alpha / self.rank | |
return up_hidden_states.to(orig_dtype) | |
class LoRACompatibleConv(nn.Conv2d): | |
""" | |
A convolutional layer that can be used with LoRA. | |
""" | |
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs): | |
deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." | |
deprecate("LoRACompatibleConv", "1.0.0", deprecation_message) | |
super().__init__(*args, **kwargs) | |
self.lora_layer = lora_layer | |
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): | |
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." | |
deprecate("set_lora_layer", "1.0.0", deprecation_message) | |
self.lora_layer = lora_layer | |
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False): | |
if self.lora_layer is None: | |
return | |
dtype, device = self.weight.data.dtype, self.weight.data.device | |
w_orig = self.weight.data.float() | |
w_up = self.lora_layer.up.weight.data.float() | |
w_down = self.lora_layer.down.weight.data.float() | |
if self.lora_layer.network_alpha is not None: | |
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank | |
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) | |
fusion = fusion.reshape((w_orig.shape)) | |
fused_weight = w_orig + (lora_scale * fusion) | |
if safe_fusing and torch.isnan(fused_weight).any().item(): | |
raise ValueError( | |
"This LoRA weight seems to be broken. " | |
f"Encountered NaN values when trying to fuse LoRA weights for {self}." | |
"LoRA weights will not be fused." | |
) | |
self.weight.data = fused_weight.to(device=device, dtype=dtype) | |
# we can drop the lora layer now | |
self.lora_layer = None | |
# offload the up and down matrices to CPU to not blow the memory | |
self.w_up = w_up.cpu() | |
self.w_down = w_down.cpu() | |
self._lora_scale = lora_scale | |
def _unfuse_lora(self): | |
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): | |
return | |
fused_weight = self.weight.data | |
dtype, device = fused_weight.data.dtype, fused_weight.data.device | |
self.w_up = self.w_up.to(device=device).float() | |
self.w_down = self.w_down.to(device).float() | |
fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) | |
fusion = fusion.reshape((fused_weight.shape)) | |
unfused_weight = fused_weight.float() - (self._lora_scale * fusion) | |
self.weight.data = unfused_weight.to(device=device, dtype=dtype) | |
self.w_up = None | |
self.w_down = None | |
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: | |
if self.padding_mode != "zeros": | |
hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode) | |
padding = (0, 0) | |
else: | |
padding = self.padding | |
original_outputs = F.conv2d( | |
hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups | |
) | |
if self.lora_layer is None: | |
return original_outputs | |
else: | |
return original_outputs + (scale * self.lora_layer(hidden_states)) | |
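# Usage sketch (illustrative only, not part of the library source): attach a
# LoRAConv2dLayer to the (deprecated) LoRACompatibleConv and fold it into the base
# weight with the private `_fuse_lora` helper. Channel sizes and rank are arbitrary.
_example_conv = LoRACompatibleConv(64, 64, kernel_size=3, padding=1)
_example_conv.set_lora_layer(LoRAConv2dLayer(in_features=64, out_features=64, rank=4, kernel_size=(3, 3), padding=1))
nn.init.normal_(_example_conv.lora_layer.up.weight, std=0.02)  # make the LoRA branch non-zero
_example_input = torch.randn(1, 64, 32, 32)
_example_unfused = _example_conv(_example_input, scale=1.0)  # base conv output + scale * LoRA branch
_example_conv._fuse_lora(lora_scale=1.0)                     # folds the low-rank update into self.weight
assert torch.allclose(_example_unfused, _example_conv(_example_input), atol=1e-4)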
class LoRACompatibleLinear(nn.Linear): | |
""" | |
A Linear layer that can be used with LoRA. | |
""" | |
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): | |
deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." | |
deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message) | |
super().__init__(*args, **kwargs) | |
self.lora_layer = lora_layer | |
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): | |
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`." | |
deprecate("set_lora_layer", "1.0.0", deprecation_message) | |
self.lora_layer = lora_layer | |
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False): | |
if self.lora_layer is None: | |
return | |
dtype, device = self.weight.data.dtype, self.weight.data.device | |
w_orig = self.weight.data.float() | |
w_up = self.lora_layer.up.weight.data.float() | |
w_down = self.lora_layer.down.weight.data.float() | |
if self.lora_layer.network_alpha is not None: | |
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank | |
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) | |
if safe_fusing and torch.isnan(fused_weight).any().item(): | |
raise ValueError( | |
"This LoRA weight seems to be broken. " | |
f"Encountered NaN values when trying to fuse LoRA weights for {self}." | |
"LoRA weights will not be fused." | |
) | |
self.weight.data = fused_weight.to(device=device, dtype=dtype) | |
# we can drop the lora layer now | |
self.lora_layer = None | |
# offload the up and down matrices to CPU to not blow the memory | |
self.w_up = w_up.cpu() | |
self.w_down = w_down.cpu() | |
self._lora_scale = lora_scale | |
def _unfuse_lora(self): | |
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): | |
return | |
fused_weight = self.weight.data | |
dtype, device = fused_weight.dtype, fused_weight.device | |
w_up = self.w_up.to(device=device).float() | |
w_down = self.w_down.to(device).float() | |
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) | |
self.weight.data = unfused_weight.to(device=device, dtype=dtype) | |
self.w_up = None | |
self.w_down = None | |
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: | |
if self.lora_layer is None: | |
out = super().forward(hidden_states) | |
return out | |
else: | |
out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) | |
return out | |
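# Usage sketch (illustrative only, not part of the library source). LoRALinearLayer
# is assumed to be defined earlier in this file; fusing adds `up.weight @ down.weight`
# to the base weight, so fused and un-fused forward passes should match.
_example_linear = LoRACompatibleLinear(128, 128)
_example_linear.set_lora_layer(LoRALinearLayer(in_features=128, out_features=128, rank=8))
nn.init.normal_(_example_linear.lora_layer.up.weight, std=0.02)  # make the LoRA branch non-zero
_example_hidden = torch.randn(4, 128)
_example_unfused = _example_linear(_example_hidden, scale=1.0)
_example_linear._fuse_lora(lora_scale=1.0)
assert torch.allclose(_example_unfused, _example_linear(_example_hidden), atol=1e-4)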
from typing import Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..utils import USE_PEFT_BACKEND | |
from .lora import LoRACompatibleConv | |
from .normalization import RMSNorm | |
class Upsample1D(nn.Module): | |
"""A 1D upsampling layer with an optional convolution. | |
Parameters: | |
channels (`int`): | |
number of channels in the inputs and outputs. | |
use_conv (`bool`, default `False`): | |
option to use a convolution. | |
use_conv_transpose (`bool`, default `False`): | |
option to use a convolution transpose. | |
out_channels (`int`, optional): | |
number of output channels. Defaults to `channels`. | |
name (`str`, default `conv`): | |
name of the upsampling 1D layer. | |
""" | |
def __init__( | |
self, | |
channels: int, | |
use_conv: bool = False, | |
use_conv_transpose: bool = False, | |
out_channels: Optional[int] = None, | |
name: str = "conv", | |
): | |
super().__init__() | |
self.channels = channels | |
self.out_channels = out_channels or channels | |
self.use_conv = use_conv | |
self.use_conv_transpose = use_conv_transpose | |
self.name = name | |
self.conv = None | |
if use_conv_transpose: | |
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) | |
elif use_conv: | |
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) | |
def forward(self, inputs: torch.Tensor) -> torch.Tensor: | |
assert inputs.shape[1] == self.channels | |
if self.use_conv_transpose: | |
return self.conv(inputs) | |
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") | |
if self.use_conv: | |
outputs = self.conv(outputs) | |
return outputs | |
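# Usage sketch (illustrative only, not part of the library source): the layer
# doubles the last (length) dimension; with use_conv=True a 3-wide convolution
# refines the nearest-neighbour result without changing the length again.
_example_up1d = Upsample1D(channels=32, use_conv=True)
assert _example_up1d(torch.randn(2, 32, 16)).shape == (2, 32, 32)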
class Upsample2D(nn.Module): | |
"""A 2D upsampling layer with an optional convolution. | |
Parameters: | |
channels (`int`): | |
number of channels in the inputs and outputs. | |
use_conv (`bool`, default `False`): | |
option to use a convolution. | |
use_conv_transpose (`bool`, default `False`): | |
option to use a convolution transpose. | |
out_channels (`int`, optional): | |
number of output channels. Defaults to `channels`. | |
name (`str`, default `conv`): | |
name of the upsampling 2D layer. | |
""" | |
def __init__( | |
self, | |
channels: int, | |
use_conv: bool = False, | |
use_conv_transpose: bool = False, | |
out_channels: Optional[int] = None, | |
name: str = "conv", | |
kernel_size: Optional[int] = None, | |
padding=1, | |
norm_type=None, | |
eps=None, | |
elementwise_affine=None, | |
bias=True, | |
interpolate=True, | |
): | |
super().__init__() | |
self.channels = channels | |
self.out_channels = out_channels or channels | |
self.use_conv = use_conv | |
self.use_conv_transpose = use_conv_transpose | |
self.name = name | |
self.interpolate = interpolate | |
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv | |
if norm_type == "ln_norm": | |
self.norm = nn.LayerNorm(channels, eps, elementwise_affine) | |
elif norm_type == "rms_norm": | |
self.norm = RMSNorm(channels, eps, elementwise_affine) | |
elif norm_type is None: | |
self.norm = None | |
else: | |
raise ValueError(f"unknown norm_type: {norm_type}") | |
conv = None | |
if use_conv_transpose: | |
if kernel_size is None: | |
kernel_size = 4 | |
conv = nn.ConvTranspose2d( | |
channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias | |
) | |
elif use_conv: | |
if kernel_size is None: | |
kernel_size = 3 | |
conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) | |
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed | |
if name == "conv": | |
self.conv = conv | |
else: | |
self.Conv2d_0 = conv | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
output_size: Optional[int] = None, | |
scale: float = 1.0, | |
) -> torch.FloatTensor: | |
assert hidden_states.shape[1] == self.channels | |
if self.norm is not None: | |
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | |
if self.use_conv_transpose: | |
return self.conv(hidden_states) | |
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 | |
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch | |
# https://github.com/pytorch/pytorch/issues/86679 | |
dtype = hidden_states.dtype | |
if dtype == torch.bfloat16: | |
hidden_states = hidden_states.to(torch.float32) | |
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 | |
if hidden_states.shape[0] >= 64: | |
hidden_states = hidden_states.contiguous() | |
# if `output_size` is passed we force the interpolation output | |
# size and do not make use of `scale_factor=2` | |
if self.interpolate: | |
if output_size is None: | |
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") | |
else: | |
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") | |
# If the input is bfloat16, we cast back to bfloat16 | |
if dtype == torch.bfloat16: | |
hidden_states = hidden_states.to(dtype) | |
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed | |
if self.use_conv: | |
if self.name == "conv": | |
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: | |
hidden_states = self.conv(hidden_states, scale) | |
else: | |
hidden_states = self.conv(hidden_states) | |
else: | |
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND: | |
hidden_states = self.Conv2d_0(hidden_states, scale) | |
else: | |
hidden_states = self.Conv2d_0(hidden_states) | |
return hidden_states | |
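# Usage sketch (illustrative only, not part of the library source): by default the
# layer doubles H and W; passing `output_size` overrides the fixed 2x factor, which
# UNet decoders use to hit resolutions that are not exact powers of two.
_example_up2d = Upsample2D(channels=64, use_conv=True)
_example_feat = torch.randn(1, 64, 16, 16)
assert _example_up2d(_example_feat).shape == (1, 64, 32, 32)
assert _example_up2d(_example_feat, output_size=33).shape == (1, 64, 33, 33)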
class FirUpsample2D(nn.Module): | |
"""A 2D FIR upsampling layer with an optional convolution. | |
Parameters: | |
channels (`int`, optional): | |
number of channels in the inputs and outputs. | |
use_conv (`bool`, default `False`): | |
option to use a convolution. | |
out_channels (`int`, optional): | |
number of output channels. Defaults to `channels`. | |
fir_kernel (`tuple`, default `(1, 3, 3, 1)`): | |
kernel for the FIR filter. | |
""" | |
def __init__( | |
self, | |
channels: Optional[int] = None, | |
out_channels: Optional[int] = None, | |
use_conv: bool = False, | |
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), | |
): | |
super().__init__() | |
out_channels = out_channels if out_channels else channels | |
if use_conv: | |
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) | |
self.use_conv = use_conv | |
self.fir_kernel = fir_kernel | |
self.out_channels = out_channels | |
def _upsample_2d( | |
self, | |
hidden_states: torch.FloatTensor, | |
weight: Optional[torch.FloatTensor] = None, | |
kernel: Optional[torch.FloatTensor] = None, | |
factor: int = 2, | |
gain: float = 1, | |
) -> torch.FloatTensor: | |
"""Fused `upsample_2d()` followed by `Conv2d()`. | |
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more | |
efficient than performing the same calculation using standard PyTorch ops. It supports gradients of | |
arbitrary order. | |
Args: | |
hidden_states (`torch.FloatTensor`): | |
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. | |
weight (`torch.FloatTensor`, *optional*): | |
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be | |
performed by `inChannels = x.shape[0] // numGroups`. | |
kernel (`torch.FloatTensor`, *optional*): | |
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which | |
corresponds to nearest-neighbor upsampling. | |
factor (`int`, *optional*): Integer upsampling factor (default: 2). | |
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). | |
Returns: | |
output (`torch.FloatTensor`): | |
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same | |
datatype as `hidden_states`. | |
""" | |
assert isinstance(factor, int) and factor >= 1 | |
# Setup filter kernel. | |
if kernel is None: | |
kernel = [1] * factor | |
# setup kernel | |
kernel = torch.tensor(kernel, dtype=torch.float32) | |
if kernel.ndim == 1: | |
kernel = torch.outer(kernel, kernel) | |
kernel /= torch.sum(kernel) | |
kernel = kernel * (gain * (factor**2)) | |
if self.use_conv: | |
convH = weight.shape[2] | |
convW = weight.shape[3] | |
inC = weight.shape[1] | |
pad_value = (kernel.shape[0] - factor) - (convW - 1) | |
stride = (factor, factor) | |
# Determine data dimensions. | |
output_shape = ( | |
(hidden_states.shape[2] - 1) * factor + convH, | |
(hidden_states.shape[3] - 1) * factor + convW, | |
) | |
output_padding = ( | |
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, | |
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, | |
) | |
assert output_padding[0] >= 0 and output_padding[1] >= 0 | |
num_groups = hidden_states.shape[1] // inC | |
# Transpose weights. | |
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) | |
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) | |
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) | |
inverse_conv = F.conv_transpose2d( | |
hidden_states, | |
weight, | |
stride=stride, | |
output_padding=output_padding, | |
padding=0, | |
) | |
output = upfirdn2d_native( | |
inverse_conv, | |
torch.tensor(kernel, device=inverse_conv.device), | |
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), | |
) | |
else: | |
pad_value = kernel.shape[0] - factor | |
output = upfirdn2d_native( | |
hidden_states, | |
torch.tensor(kernel, device=hidden_states.device), | |
up=factor, | |
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), | |
) | |
return output | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
if self.use_conv: | |
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) | |
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) | |
else: | |
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) | |
return height | |
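# Usage sketch (illustrative only, not part of the library source): the default
# (1, 3, 3, 1) separable FIR filter smooths the 2x upsample; with use_conv=True the
# learned 3x3 convolution is fused into a single transposed convolution.
_example_fir_up = FirUpsample2D(channels=32, out_channels=32, use_conv=True)
assert _example_fir_up(torch.randn(1, 32, 16, 16)).shape == (1, 32, 32, 32)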
class KUpsample2D(nn.Module): | |
r"""A 2D K-upsampling layer. | |
Parameters: | |
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. | |
""" | |
def __init__(self, pad_mode: str = "reflect"): | |
super().__init__() | |
self.pad_mode = pad_mode | |
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 | |
self.pad = kernel_1d.shape[1] // 2 - 1 | |
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) | |
def forward(self, inputs: torch.Tensor) -> torch.Tensor: | |
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) | |
weight = inputs.new_zeros( | |
[ | |
inputs.shape[1], | |
inputs.shape[1], | |
self.kernel.shape[0], | |
self.kernel.shape[1], | |
] | |
) | |
indices = torch.arange(inputs.shape[1], device=inputs.device) | |
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) | |
weight[indices, indices] = kernel | |
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) | |
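# Usage sketch (illustrative only, not part of the library source): KUpsample2D
# doubles H and W with a fixed smoothing kernel applied independently per channel
# (the transposed-convolution weight is block-diagonal across channels).
_example_k_up = KUpsample2D()
assert _example_k_up(torch.randn(1, 8, 16, 16)).shape == (1, 8, 32, 32)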
def upfirdn2d_native( | |
tensor: torch.Tensor, | |
kernel: torch.Tensor, | |
up: int = 1, | |
down: int = 1, | |
pad: Tuple[int, int] = (0, 0), | |
) -> torch.Tensor: | |
up_x = up_y = up | |
down_x = down_y = down | |
pad_x0 = pad_y0 = pad[0] | |
pad_x1 = pad_y1 = pad[1] | |
_, channel, in_h, in_w = tensor.shape | |
tensor = tensor.reshape(-1, in_h, in_w, 1) | |
_, in_h, in_w, minor = tensor.shape | |
kernel_h, kernel_w = kernel.shape | |
out = tensor.view(-1, in_h, 1, in_w, 1, minor) | |
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) | |
out = out.view(-1, in_h * up_y, in_w * up_x, minor) | |
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) | |
out = out.to(tensor.device) # Move back to mps if necessary | |
out = out[ | |
:, | |
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), | |
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), | |
:, | |
] | |
out = out.permute(0, 3, 1, 2) | |
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) | |
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) | |
out = F.conv2d(out, w) | |
out = out.reshape( | |
-1, | |
minor, | |
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, | |
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, | |
) | |
out = out.permute(0, 2, 3, 1) | |
out = out[:, ::down_y, ::down_x, :] | |
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 | |
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 | |
return out.view(-1, channel, out_h, out_w) | |
def upsample_2d( | |
hidden_states: torch.FloatTensor, | |
kernel: Optional[torch.FloatTensor] = None, | |
factor: int = 2, | |
gain: float = 1, | |
) -> torch.FloatTensor: | |
r"""Upsample2D a batch of 2D images with the given filter. | |
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given | |
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified | |
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is | |
a multiple of the upsampling factor. | |
Args: | |
hidden_states (`torch.FloatTensor`): | |
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. | |
kernel (`torch.FloatTensor`, *optional*): | |
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which | |
corresponds to nearest-neighbor upsampling. | |
factor (`int`, *optional*, default to `2`): | |
Integer upsampling factor. | |
gain (`float`, *optional*, default to `1.0`): | |
Scaling factor for signal magnitude (default: 1.0). | |
Returns: | |
output (`torch.FloatTensor`): | |
Tensor of the shape `[N, C, H * factor, W * factor]` | |
""" | |
assert isinstance(factor, int) and factor >= 1 | |
if kernel is None: | |
kernel = [1] * factor | |
kernel = torch.tensor(kernel, dtype=torch.float32) | |
if kernel.ndim == 1: | |
kernel = torch.outer(kernel, kernel) | |
kernel /= torch.sum(kernel) | |
kernel = kernel * (gain * (factor**2)) | |
pad_value = kernel.shape[0] - factor | |
output = upfirdn2d_native( | |
hidden_states, | |
kernel.to(device=hidden_states.device), | |
up=factor, | |
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), | |
) | |
return output | |
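# Usage sketch (illustrative only, not part of the library source): with no kernel
# this reduces to nearest-neighbour upsampling; a (1, 3, 3, 1) kernel reproduces the
# FIR-filtered upsample used by FirUpsample2D, with the same output shape.
_example_sample = torch.randn(1, 4, 8, 8)
assert upsample_2d(_example_sample, factor=2).shape == (1, 4, 16, 16)
assert upsample_2d(_example_sample, kernel=torch.tensor([1.0, 3.0, 3.0, 1.0]), factor=2).shape == (1, 4, 16, 16)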
from ..utils import deprecate | |
from .transformers.prior_transformer import PriorTransformer, PriorTransformerOutput | |
class PriorTransformerOutput(PriorTransformerOutput): | |
deprecation_message = "Importing `PriorTransformerOutput` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformerOutput`, instead." | |
deprecate("PriorTransformerOutput", "0.29", deprecation_message) | |
class PriorTransformer(PriorTransformer): | |
deprecation_message = "Importing `PriorTransformer` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformer`, instead." | |
deprecate("PriorTransformer", "0.29", deprecation_message) | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
class FlaxUpsample2D(nn.Module): | |
out_channels: int | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.conv = nn.Conv( | |
self.out_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
def __call__(self, hidden_states): | |
batch, height, width, channels = hidden_states.shape | |
hidden_states = jax.image.resize( | |
hidden_states, | |
shape=(batch, height * 2, width * 2, channels), | |
method="nearest", | |
) | |
hidden_states = self.conv(hidden_states) | |
return hidden_states | |
class FlaxDownsample2D(nn.Module): | |
out_channels: int | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.conv = nn.Conv( | |
self.out_channels, | |
kernel_size=(3, 3), | |
strides=(2, 2), | |
padding=((1, 1), (1, 1)), # padding="VALID", | |
dtype=self.dtype, | |
) | |
def __call__(self, hidden_states): | |
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim | |
# hidden_states = jnp.pad(hidden_states, pad_width=pad) | |
hidden_states = self.conv(hidden_states) | |
return hidden_states | |
class FlaxResnetBlock2D(nn.Module): | |
in_channels: int | |
out_channels: int = None | |
dropout_prob: float = 0.0 | |
use_nin_shortcut: bool = None | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
out_channels = self.in_channels if self.out_channels is None else self.out_channels | |
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) | |
self.conv1 = nn.Conv( | |
out_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) | |
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) | |
self.dropout = nn.Dropout(self.dropout_prob) | |
self.conv2 = nn.Conv( | |
out_channels, | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut | |
self.conv_shortcut = None | |
if use_nin_shortcut: | |
self.conv_shortcut = nn.Conv( | |
out_channels, | |
kernel_size=(1, 1), | |
strides=(1, 1), | |
padding="VALID", | |
dtype=self.dtype, | |
) | |
def __call__(self, hidden_states, temb, deterministic=True): | |
residual = hidden_states | |
hidden_states = self.norm1(hidden_states) | |
hidden_states = nn.swish(hidden_states) | |
hidden_states = self.conv1(hidden_states) | |
temb = self.time_emb_proj(nn.swish(temb)) | |
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) | |
hidden_states = hidden_states + temb | |
hidden_states = self.norm2(hidden_states) | |
hidden_states = nn.swish(hidden_states) | |
hidden_states = self.dropout(hidden_states, deterministic) | |
hidden_states = self.conv2(hidden_states) | |
if self.conv_shortcut is not None: | |
residual = self.conv_shortcut(residual) | |
return hidden_states + residual | |
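# Usage sketch (illustrative only, not part of the library source): Flax modules are
# initialised lazily, so `init` traces __call__ with example NHWC inputs and returns
# the parameter pytree that is later passed to `apply`. Shapes here are arbitrary.
_example_rng = jax.random.PRNGKey(0)
_example_block = FlaxResnetBlock2D(in_channels=64, out_channels=64)
_example_sample = jnp.ones((1, 32, 32, 64))  # (batch, height, width, channels)
_example_temb = jnp.ones((1, 512))           # time embedding
_example_params = _example_block.init(_example_rng, _example_sample, _example_temb)
_example_out = _example_block.apply(_example_params, _example_sample, _example_temb)
assert _example_out.shape == (1, 32, 32, 64)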
# Copyright 2022 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
from typing import Callable, List, Optional, Union | |
import torch | |
import torch.nn as nn | |
from ..configuration_utils import ConfigMixin, register_to_config | |
from ..utils import logging | |
from .modeling_utils import ModelMixin | |
logger = logging.get_logger(__name__) | |
class MultiAdapter(ModelMixin): | |
r""" | |
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to | |
user-assigned weighting. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library | |
implements for all models (such as downloading or saving). | |
Parameters: | |
adapters (`List[T2IAdapter]`, *optional*, defaults to None): | |
A list of `T2IAdapter` model instances. | |
""" | |
def __init__(self, adapters: List["T2IAdapter"]): | |
super(MultiAdapter, self).__init__() | |
self.num_adapter = len(adapters) | |
self.adapters = nn.ModuleList(adapters) | |
if len(adapters) == 0: | |
raise ValueError("Expecting at least one adapter") | |
if len(adapters) == 1: | |
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`") | |
# The outputs from each adapter are added together with a weight. | |
# This means that the change in dimensions from downsampling must | |
# be the same for all adapters. Inductively, it also means the | |
# downscale_factor and total_downscale_factor must be the same for all | |
# adapters. | |
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor | |
first_adapter_downscale_factor = adapters[0].downscale_factor | |
for idx in range(1, len(adapters)): | |
if ( | |
adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor | |
or adapters[idx].downscale_factor != first_adapter_downscale_factor | |
): | |
raise ValueError( | |
f"Expecting all adapters to have the same downscaling behavior, but got:\n" | |
f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n" | |
f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n" | |
f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n" | |
f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}" | |
) | |
self.total_downscale_factor = first_adapter_total_downscale_factor | |
self.downscale_factor = first_adapter_downscale_factor | |
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]: | |
r""" | |
Args: | |
xs (`torch.Tensor`): | |
(batch, channel, height, width) input images for the multiple adapter models, concatenated along | |
dimension 1; `channel` should equal `num_adapter` * (number of channels per input image). | |
adapter_weights (`List[float]`, *optional*, defaults to None): | |
List of floats representing the weight by which each adapter's output is multiplied before the | |
outputs are added together. | |
""" | |
if adapter_weights is None: | |
adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter) | |
else: | |
adapter_weights = torch.tensor(adapter_weights) | |
accume_state = None | |
for x, w, adapter in zip(xs, adapter_weights, self.adapters): | |
features = adapter(x) | |
if accume_state is None: | |
accume_state = features | |
for i in range(len(accume_state)): | |
accume_state[i] = w * accume_state[i] | |
else: | |
for i in range(len(features)): | |
accume_state[i] += w * features[i] | |
return accume_state | |
def save_pretrained( | |
self, | |
save_directory: Union[str, os.PathLike], | |
is_main_process: bool = True, | |
save_function: Callable = None, | |
safe_serialization: bool = True, | |
variant: Optional[str] = None, | |
): | |
""" | |
Save a model and its configuration file to a directory, so that it can be re-loaded using the | |
`[`~models.adapter.MultiAdapter.from_pretrained`]` class method. | |
Arguments: | |
save_directory (`str` or `os.PathLike`): | |
Directory to which to save. Will be created if it doesn't exist. | |
is_main_process (`bool`, *optional*, defaults to `True`): | |
Whether the process calling this is the main process or not. Useful during distributed training (e.g., | |
on TPUs) when this function needs to be called on all processes. In this case, set `is_main_process=True` only on | |
the main process to avoid race conditions. | |
save_function (`Callable`): | |
The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs) when one | |
needs to replace `torch.save` with another method. Can be configured with the environment variable | |
`DIFFUSERS_SAVE_MODE`. | |
safe_serialization (`bool`, *optional*, defaults to `True`): | |
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). | |
variant (`str`, *optional*): | |
If specified, weights are saved in the format pytorch_model.<variant>.bin. | |
""" | |
idx = 0 | |
model_path_to_save = save_directory | |
for adapter in self.adapters: | |
adapter.save_pretrained( | |
model_path_to_save, | |
is_main_process=is_main_process, | |
save_function=save_function, | |
safe_serialization=safe_serialization, | |
variant=variant, | |
) | |
idx += 1 | |
model_path_to_save = model_path_to_save + f"_{idx}" | |
@classmethod | |
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): | |
r""" | |
Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models. | |
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train | |
the model, you should first set it back in training mode with `model.train()`. | |
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come | |
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning | |
task. | |
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those | |
weights are discarded. | |
Parameters: | |
pretrained_model_path (`os.PathLike`): | |
A path to a *directory* containing model weights saved using | |
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`. | |
torch_dtype (`str` or `torch.dtype`, *optional*): | |
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype | |
will be automatically derived from the model's weights. | |
output_loading_info(`bool`, *optional*, defaults to `False`): | |
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. | |
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): | |
A map that specifies where each submodule should go. It doesn't need to be refined to each | |
parameter/buffer name; once a given module name is included, every submodule of it will be sent to the | |
same device. | |
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For | |
more information about each option see [designing a device | |
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). | |
max_memory (`Dict`, *optional*): | |
A dictionary mapping each device identifier to its maximum memory. Will default to the maximum memory available for each | |
GPU and the available CPU RAM if unset. | |
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): | |
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This | |
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the | |
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, | |
setting this argument to `True` will raise an error. | |
variant (`str`, *optional*): | |
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is | |
ignored when using `from_flax`. | |
use_safetensors (`bool`, *optional*, defaults to `None`): | |
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the | |
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from | |
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`. | |
""" | |
idx = 0 | |
adapters = [] | |
# load adapter and append to list until no adapter directory exists anymore | |
# first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained` | |
# second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ... | |
model_path_to_load = pretrained_model_path | |
while os.path.isdir(model_path_to_load): | |
adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs) | |
adapters.append(adapter) | |
idx += 1 | |
model_path_to_load = pretrained_model_path + f"_{idx}" | |
logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.") | |
if len(adapters) == 0: | |
raise ValueError( | |
f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." | |
) | |
return cls(adapters) | |
class T2IAdapter(ModelMixin, ConfigMixin): | |
r""" | |
A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model | |
generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's | |
architecture follows the original implementation of | |
[Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97) | |
and | |
[AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235). | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library | |
implements for all models (such as downloading or saving). | |
Parameters: | |
in_channels (`int`, *optional*, defaults to 3): | |
Number of channels of the Adapter's input (*control image*). Set this parameter to 1 if you're using a | |
grayscale image as the *control image*. | |
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
The number of channels in each downsample block's output hidden state. `len(channels)` also | |
determines the number of downsample blocks in the Adapter. | |
num_res_blocks (`int`, *optional*, defaults to 2): | |
Number of ResNet blocks in each downsample block. | |
downscale_factor (`int`, *optional*, defaults to 8): | |
A factor that determines the total downscale factor of the Adapter. | |
adapter_type (`str`, *optional*, defaults to `full_adapter`): | |
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`. | |
""" | |
@register_to_config | |
def __init__( | |
self, | |
in_channels: int = 3, | |
channels: List[int] = [320, 640, 1280, 1280], | |
num_res_blocks: int = 2, | |
downscale_factor: int = 8, | |
adapter_type: str = "full_adapter", | |
): | |
super().__init__() | |
if adapter_type == "full_adapter": | |
self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor) | |
elif adapter_type == "full_adapter_xl": | |
self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor) | |
elif adapter_type == "light_adapter": | |
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor) | |
else: | |
raise ValueError( | |
f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or " | |
"'full_adapter_xl' or 'light_adapter'." | |
) | |
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: | |
r""" | |
This function processes the input tensor `x` through the adapter model and returns a list of feature tensors, | |
each representing information extracted at a different scale from the input. The length of the list is | |
determined by the number of downsample blocks in the Adapter, as specified by the `channels` and | |
`num_res_blocks` parameters during initialization. | |
""" | |
return self.adapter(x) | |
@property | |
def total_downscale_factor(self): | |
return self.adapter.total_downscale_factor | |
@property | |
def downscale_factor(self): | |
"""The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are | |
not evenly divisible by the downscale_factor then an exception will be raised. | |
""" | |
return self.adapter.unshuffle.downscale_factor | |
# full adapter | |
class FullAdapter(nn.Module): | |
r""" | |
See [`T2IAdapter`] for more information. | |
""" | |
def __init__( | |
self, | |
in_channels: int = 3, | |
channels: List[int] = [320, 640, 1280, 1280], | |
num_res_blocks: int = 2, | |
downscale_factor: int = 8, | |
): | |
super().__init__() | |
in_channels = in_channels * downscale_factor**2 | |
self.unshuffle = nn.PixelUnshuffle(downscale_factor) | |
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1) | |
self.body = nn.ModuleList( | |
[ | |
AdapterBlock(channels[0], channels[0], num_res_blocks), | |
*[ | |
AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True) | |
for i in range(1, len(channels)) | |
], | |
] | |
) | |
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1) | |
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: | |
r""" | |
This method processes the input tensor `x` through the FullAdapter model and performs operations including | |
pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each | |
capturing information at a different stage of processing within the FullAdapter model. The number of feature | |
tensors in the list is determined by the number of downsample blocks specified during initialization. | |
""" | |
x = self.unshuffle(x) | |
x = self.conv_in(x) | |
features = [] | |
for block in self.body: | |
x = block(x) | |
features.append(x) | |
return features | |
class FullAdapterXL(nn.Module): | |
r""" | |
See [`T2IAdapter`] for more information. | |
""" | |
def __init__( | |
self, | |
in_channels: int = 3, | |
channels: List[int] = [320, 640, 1280, 1280], | |
num_res_blocks: int = 2, | |
downscale_factor: int = 16, | |
): | |
super().__init__() | |
in_channels = in_channels * downscale_factor**2 | |
self.unshuffle = nn.PixelUnshuffle(downscale_factor) | |
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1) | |
self.body = [] | |
# blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32] | |
for i in range(len(channels)): | |
if i == 1: | |
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks)) | |
elif i == 2: | |
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)) | |
else: | |
self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks)) | |
self.body = nn.ModuleList(self.body) | |
# XL has only one downsampling AdapterBlock. | |
self.total_downscale_factor = downscale_factor * 2 | |
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: | |
r""" | |
This method takes the tensor `x` as input and processes it through the FullAdapterXL model: pixels are | |
unshuffled, a convolution is applied, and the output of each block is appended to a list of feature tensors. | |
""" | |
x = self.unshuffle(x) | |
x = self.conv_in(x) | |
features = [] | |
for block in self.body: | |
x = block(x) | |
features.append(x) | |
return features | |
class AdapterBlock(nn.Module): | |
r""" | |
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and | |
`FullAdapterXL` models. | |
Parameters: | |
in_channels (`int`): | |
Number of channels of AdapterBlock's input. | |
out_channels (`int`): | |
Number of channels of AdapterBlock's output. | |
num_res_blocks (`int`): | |
Number of ResNet blocks in the AdapterBlock. | |
down (`bool`, *optional*, defaults to `False`): | |
Whether to perform downsampling on AdapterBlock's input. | |
""" | |
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): | |
super().__init__() | |
self.downsample = None | |
if down: | |
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) | |
self.in_conv = None | |
if in_channels != out_channels: | |
self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) | |
self.resnets = nn.Sequential( | |
*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)], | |
) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
r""" | |
This method takes tensor `x` as input and applies downsampling and an input convolution if the | |
`self.downsample` and `self.in_conv` attributes of the AdapterBlock are set. It then applies a series of | |
residual blocks to the tensor. | |
""" | |
if self.downsample is not None: | |
x = self.downsample(x) | |
if self.in_conv is not None: | |
x = self.in_conv(x) | |
x = self.resnets(x) | |
return x | |
class AdapterResnetBlock(nn.Module): | |
r""" | |
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block. | |
Parameters: | |
channels (`int`): | |
Number of channels of AdapterResnetBlock's input and output. | |
""" | |
def __init__(self, channels: int): | |
super().__init__() | |
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) | |
self.act = nn.ReLU() | |
self.block2 = nn.Conv2d(channels, channels, kernel_size=1) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
r""" | |
This method takes input tensor `x` and applies a convolutional layer, a ReLU activation, and another convolutional | |
layer to it. It returns the result added to the input tensor (a residual connection). | |
""" | |
h = self.act(self.block1(x)) | |
h = self.block2(h) | |
return h + x | |
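# Usage sketch (illustrative only, not part of the library source): with the default
# "full_adapter" configuration above, a 512x512 control image yields four feature maps
# whose resolutions shrink by 2x per block (total downscale factor 8 * 2**3 = 64).
_example_adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280, 1280], downscale_factor=8)
_example_features = _example_adapter(torch.randn(1, 3, 512, 512))
assert [f.shape[-1] for f in _example_features] == [64, 32, 16, 8]
assert _example_adapter.total_downscale_factor == 64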
# light adapter | |
class LightAdapter(nn.Module): | |
r""" | |
See [`T2IAdapter`] for more information. | |
""" | |
def __init__( | |
self, | |
in_channels: int = 3, | |
channels: List[int] = [320, 640, 1280], | |
num_res_blocks: int = 4, | |
downscale_factor: int = 8, | |
): | |
super().__init__() | |
in_channels = in_channels * downscale_factor**2 | |
self.unshuffle = nn.PixelUnshuffle(downscale_factor) | |
self.body = nn.ModuleList( | |
[ | |
LightAdapterBlock(in_channels, channels[0], num_res_blocks), | |
*[ | |
LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True) | |
for i in range(len(channels) - 1) | |
], | |
LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True), | |
] | |
) | |
self.total_downscale_factor = downscale_factor * (2 ** len(channels)) | |
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: | |
r""" | |
This method takes the input tensor `x`, downscales it, and collects the output of each block in a list of | |
feature tensors. Each feature tensor corresponds to a different level of processing within the LightAdapter. | |
""" | |
x = self.unshuffle(x) | |
features = [] | |
for block in self.body: | |
x = block(x) | |
features.append(x) | |
return features | |
class LightAdapterBlock(nn.Module): | |
r""" | |
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the | |
`LightAdapter` model. | |
Parameters: | |
in_channels (`int`): | |
Number of channels of LightAdapterBlock's input. | |
out_channels (`int`): | |
Number of channels of LightAdapterBlock's output. | |
num_res_blocks (`int`): | |
Number of LightAdapterResnetBlocks in the LightAdapterBlock. | |
down (`bool`, *optional*, defaults to `False`): | |
Whether to perform downsampling on LightAdapterBlock's input. | |
""" | |
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False): | |
super().__init__() | |
mid_channels = out_channels // 4 | |
self.downsample = None | |
if down: | |
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) | |
self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1) | |
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)]) | |
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
r""" | |
This method takes tensor `x` as input and performs downsampling if required. It then applies the input | |
convolution, a sequence of residual blocks, and the output convolution. | |
""" | |
if self.downsample is not None: | |
x = self.downsample(x) | |
x = self.in_conv(x) | |
x = self.resnets(x) | |
x = self.out_conv(x) | |
return x | |
class LightAdapterResnetBlock(nn.Module): | |
""" | |
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different | |
architecture than `AdapterResnetBlock`. | |
Parameters: | |
channels (`int`): | |
Number of channels of LightAdapterResnetBlock's input and output. | |
""" | |
def __init__(self, channels: int): | |
super().__init__() | |
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) | |
self.act = nn.ReLU() | |
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
r""" | |
This function takes input tensor `x`, processes it through one convolutional layer, a ReLU activation, and | |
another convolutional layer, and adds the result to the input tensor. | |
""" | |
h = self.act(self.block1(x)) | |
h = self.block2(h) | |
return h + x | |
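# Usage sketch (illustrative only, not part of the library source): the light variant
# uses narrow bottleneck blocks and one extra downsampling stage, so three `channels`
# entries still give four feature maps and a total downscale factor of 8 * 2**3 = 64.
_example_light = LightAdapter(in_channels=3, channels=[320, 640, 1280], downscale_factor=8)
_example_light_features = _example_light(torch.randn(1, 3, 512, 512))
assert [f.shape[1] for f in _example_light_features] == [320, 640, 1280, 1280]
assert _example_light.total_downscale_factor == 64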
import inspect | |
from importlib import import_module | |
from typing import Callable, Optional, Union | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ..image_processor import IPAdapterMaskProcessor | |
from ..utils import USE_PEFT_BACKEND, deprecate, logging | |
from ..utils.import_utils import is_xformers_available | |
from ..utils.torch_utils import maybe_allow_in_graph | |
from .lora import LoRACompatibleLinear, LoRALinearLayer | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
if is_xformers_available(): | |
import xformers | |
import xformers.ops | |
else: | |
xformers = None | |
@maybe_allow_in_graph | |
class Attention(nn.Module): | |
r""" | |
A cross attention layer. | |
Parameters: | |
query_dim (`int`): | |
The number of channels in the query. | |
cross_attention_dim (`int`, *optional*): | |
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. | |
heads (`int`, *optional*, defaults to 8): | |
The number of heads to use for multi-head attention. | |
dim_head (`int`, *optional*, defaults to 64): | |
The number of channels in each head. | |
dropout (`float`, *optional*, defaults to 0.0): | |
The dropout probability to use. | |
bias (`bool`, *optional*, defaults to False): | |
Set to `True` for the query, key, and value linear layers to contain a bias parameter. | |
upcast_attention (`bool`, *optional*, defaults to False): | |
Set to `True` to upcast the attention computation to `float32`. | |
upcast_softmax (`bool`, *optional*, defaults to False): | |
Set to `True` to upcast the softmax computation to `float32`. | |
cross_attention_norm (`str`, *optional*, defaults to `None`): | |
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. | |
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): | |
The number of groups to use for the group norm in the cross attention. | |
added_kv_proj_dim (`int`, *optional*, defaults to `None`): | |
The number of channels to use for the added key and value projections. If `None`, no projection is used. | |
norm_num_groups (`int`, *optional*, defaults to `None`): | |
The number of groups to use for the group norm in the attention. | |
spatial_norm_dim (`int`, *optional*, defaults to `None`): | |
The number of channels to use for the spatial normalization. | |
out_bias (`bool`, *optional*, defaults to `True`): | |
Set to `True` to use a bias in the output linear layer. | |
scale_qk (`bool`, *optional*, defaults to `True`): | |
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. | |
only_cross_attention (`bool`, *optional*, defaults to `False`): | |
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if | |
`added_kv_proj_dim` is not `None`. | |
eps (`float`, *optional*, defaults to 1e-5): | |
An additional value added to the denominator in group normalization that is used for numerical stability. | |
rescale_output_factor (`float`, *optional*, defaults to 1.0): | |
A factor to rescale the output by dividing it with this value. | |
residual_connection (`bool`, *optional*, defaults to `False`): | |
Set to `True` to add the residual connection to the output. | |
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): | |
Set to `True` if the attention block is loaded from a deprecated state dict. | |
processor (`AttnProcessor`, *optional*, defaults to `None`): | |
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and | |
`AttnProcessor` otherwise. | |
""" | |
def __init__( | |
self, | |
query_dim: int, | |
cross_attention_dim: Optional[int] = None, | |
heads: int = 8, | |
dim_head: int = 64, | |
dropout: float = 0.0, | |
bias: bool = False, | |
upcast_attention: bool = False, | |
upcast_softmax: bool = False, | |
cross_attention_norm: Optional[str] = None, | |
cross_attention_norm_num_groups: int = 32, | |
added_kv_proj_dim: Optional[int] = None, | |
norm_num_groups: Optional[int] = None, | |
spatial_norm_dim: Optional[int] = None, | |
out_bias: bool = True, | |
scale_qk: bool = True, | |
only_cross_attention: bool = False, | |
eps: float = 1e-5, | |
rescale_output_factor: float = 1.0, | |
residual_connection: bool = False, | |
_from_deprecated_attn_block: bool = False, | |
processor: Optional["AttnProcessor"] = None, | |
out_dim: int = None, | |
): | |
super().__init__() | |
self.inner_dim = out_dim if out_dim is not None else dim_head * heads | |
self.query_dim = query_dim | |
self.use_bias = bias | |
self.is_cross_attention = cross_attention_dim is not None | |
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim | |
self.upcast_attention = upcast_attention | |
self.upcast_softmax = upcast_softmax | |
self.rescale_output_factor = rescale_output_factor | |
self.residual_connection = residual_connection | |
self.dropout = dropout | |
self.fused_projections = False | |
self.out_dim = out_dim if out_dim is not None else query_dim | |
# we make use of this private variable to know whether this class is loaded | |
# with a deprecated state dict so that we can convert it on the fly | |
self._from_deprecated_attn_block = _from_deprecated_attn_block | |
self.scale_qk = scale_qk | |
self.scale = dim_head**-0.5 if self.scale_qk else 1.0 | |
self.heads = out_dim // dim_head if out_dim is not None else heads | |
# for slice_size > 0 the attention score computation | |
# is split across the batch axis to save memory | |
# You can set slice_size with `set_attention_slice` | |
self.sliceable_head_dim = heads | |
self.added_kv_proj_dim = added_kv_proj_dim | |
self.only_cross_attention = only_cross_attention | |
if self.added_kv_proj_dim is None and self.only_cross_attention: | |
raise ValueError( | |
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." | |
) | |
if norm_num_groups is not None: | |
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) | |
else: | |
self.group_norm = None | |
if spatial_norm_dim is not None: | |
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) | |
else: | |
self.spatial_norm = None | |
if cross_attention_norm is None: | |
self.norm_cross = None | |
elif cross_attention_norm == "layer_norm": | |
self.norm_cross = nn.LayerNorm(self.cross_attention_dim) | |
elif cross_attention_norm == "group_norm": | |
if self.added_kv_proj_dim is not None: | |
# The given `encoder_hidden_states` are initially of shape | |
# (batch_size, seq_len, added_kv_proj_dim) before being projected | |
# to (batch_size, seq_len, cross_attention_dim). The norm is applied | |
# before the projection, so we need to use `added_kv_proj_dim` as | |
# the number of channels for the group norm. | |
norm_cross_num_channels = added_kv_proj_dim | |
else: | |
norm_cross_num_channels = self.cross_attention_dim | |
self.norm_cross = nn.GroupNorm( | |
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True | |
) | |
else: | |
raise ValueError( | |
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" | |
) | |
if USE_PEFT_BACKEND: | |
linear_cls = nn.Linear | |
else: | |
linear_cls = LoRACompatibleLinear | |
self.linear_cls = linear_cls | |
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) | |
if not self.only_cross_attention: | |
# only relevant for the `AddedKVProcessor` classes | |
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) | |
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) | |
else: | |
self.to_k = None | |
self.to_v = None | |
if self.added_kv_proj_dim is not None: | |
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) | |
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) | |
self.to_out = nn.ModuleList([]) | |
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) | |
self.to_out.append(nn.Dropout(dropout)) | |
# set attention processor | |
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses | |
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention | |
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 | |
if processor is None: | |
processor = ( | |
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() | |
) | |
self.set_processor(processor) | |
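    # Usage sketch (comments only; illustrative, not part of the library source):
    #   attn = Attention(query_dim=320, heads=8, dim_head=40)              # self-attention
    #   out = attn(hidden_states)                                          # (batch, seq_len, 320)
    #   xattn = Attention(query_dim=320, cross_attention_dim=768)          # cross-attention over text states
    #   out = xattn(hidden_states, encoder_hidden_states=text_embeddings)
    # The actual attention computation is delegated to `self.processor`, so swapping
    # processors (xformers, sliced, added-KV, LoRA/PEFT) changes the kernel without
    # touching the projection layers created above.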
def set_use_memory_efficient_attention_xformers( | |
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None | |
) -> None: | |
r""" | |
Set whether to use memory efficient attention from `xformers` or not. | |
Args: | |
use_memory_efficient_attention_xformers (`bool`): | |
Whether to use memory efficient attention from `xformers` or not. | |
attention_op (`Callable`, *optional*): | |
The attention operation to use. Defaults to `None` which uses the default attention operation from | |
`xformers`. | |
""" | |
is_lora = hasattr(self, "processor") and isinstance( | |
self.processor, | |
LORA_ATTENTION_PROCESSORS, | |
) | |
is_custom_diffusion = hasattr(self, "processor") and isinstance( | |
self.processor, | |
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), | |
) | |
is_added_kv_processor = hasattr(self, "processor") and isinstance( | |
self.processor, | |
( | |
AttnAddedKVProcessor, | |
AttnAddedKVProcessor2_0, | |
SlicedAttnAddedKVProcessor, | |
XFormersAttnAddedKVProcessor, | |
LoRAAttnAddedKVProcessor, | |
), | |
) | |
if use_memory_efficient_attention_xformers: | |
if is_added_kv_processor and (is_lora or is_custom_diffusion): | |
raise NotImplementedError( | |
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" | |
) | |
if not is_xformers_available(): | |
raise ModuleNotFoundError( | |
( | |
"Refer to https://github.com/facebookresearch/xformers for more information on how to install" | |
" xformers" | |
), | |
name="xformers", | |
) | |
elif not torch.cuda.is_available(): | |
raise ValueError( | |
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" | |
" only available for GPU " | |
) | |
else: | |
try: | |
# Make sure we can run the memory efficient attention | |
_ = xformers.ops.memory_efficient_attention( | |
torch.randn((1, 2, 40), device="cuda"), | |
torch.randn((1, 2, 40), device="cuda"), | |
torch.randn((1, 2, 40), device="cuda"), | |
) | |
except Exception as e: | |
raise e | |
if is_lora: | |
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers | |
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? | |
processor = LoRAXFormersAttnProcessor( | |
hidden_size=self.processor.hidden_size, | |
cross_attention_dim=self.processor.cross_attention_dim, | |
rank=self.processor.rank, | |
attention_op=attention_op, | |
) | |
processor.load_state_dict(self.processor.state_dict()) | |
processor.to(self.processor.to_q_lora.up.weight.device) | |
elif is_custom_diffusion: | |
processor = CustomDiffusionXFormersAttnProcessor( | |
train_kv=self.processor.train_kv, | |
train_q_out=self.processor.train_q_out, | |
hidden_size=self.processor.hidden_size, | |
cross_attention_dim=self.processor.cross_attention_dim, | |
attention_op=attention_op, | |
) | |
processor.load_state_dict(self.processor.state_dict()) | |
if hasattr(self.processor, "to_k_custom_diffusion"): | |
processor.to(self.processor.to_k_custom_diffusion.weight.device) | |
elif is_added_kv_processor: | |
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP | |
# which uses this type of cross attention ONLY because the attention mask of format | |
# [0, ..., -10.000, ..., 0, ...,] is not supported | |
# throw warning | |
logger.info( | |
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." | |
) | |
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) | |
else: | |
processor = XFormersAttnProcessor(attention_op=attention_op) | |
else: | |
if is_lora: | |
attn_processor_class = ( | |
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor | |
) | |
processor = attn_processor_class( | |
hidden_size=self.processor.hidden_size, | |
cross_attention_dim=self.processor.cross_attention_dim, | |
rank=self.processor.rank, | |
) | |
processor.load_state_dict(self.processor.state_dict()) | |
processor.to(self.processor.to_q_lora.up.weight.device) | |
elif is_custom_diffusion: | |
attn_processor_class = ( | |
CustomDiffusionAttnProcessor2_0 | |
if hasattr(F, "scaled_dot_product_attention") | |
else CustomDiffusionAttnProcessor | |
) | |
processor = attn_processor_class( | |
train_kv=self.processor.train_kv, | |
train_q_out=self.processor.train_q_out, | |
hidden_size=self.processor.hidden_size, | |
cross_attention_dim=self.processor.cross_attention_dim, | |
) | |
processor.load_state_dict(self.processor.state_dict()) | |
if hasattr(self.processor, "to_k_custom_diffusion"): | |
processor.to(self.processor.to_k_custom_diffusion.weight.device) | |
else: | |
# set attention processor | |
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses | |
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention | |
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 | |
processor = ( | |
AttnProcessor2_0() | |
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk | |
else AttnProcessor() | |
) | |
self.set_processor(processor) | |
def set_attention_slice(self, slice_size: int) -> None: | |
r""" | |
Set the slice size for attention computation. | |
Args: | |
slice_size (`int`): | |
The slice size for attention computation. | |
""" | |
if slice_size is not None and slice_size > self.sliceable_head_dim: | |
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") | |
if slice_size is not None and self.added_kv_proj_dim is not None: | |
processor = SlicedAttnAddedKVProcessor(slice_size) | |
elif slice_size is not None: | |
processor = SlicedAttnProcessor(slice_size) | |
elif self.added_kv_proj_dim is not None: | |
processor = AttnAddedKVProcessor() | |
else: | |
# set attention processor | |
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses | |
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention | |
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 | |
processor = ( | |
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() | |
) | |
self.set_processor(processor) | |
def set_processor(self, processor: "AttnProcessor") -> None: | |
r""" | |
Set the attention processor to use. | |
Args: | |
processor (`AttnProcessor`): | |
The attention processor to use. | |
""" | |
# if current processor is in `self._modules` and if passed `processor` is not, we need to | |
# pop `processor` from `self._modules` | |
if ( | |
hasattr(self, "processor") | |
and isinstance(self.processor, torch.nn.Module) | |
and not isinstance(processor, torch.nn.Module) | |
): | |
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") | |
self._modules.pop("processor") | |
self.processor = processor | |
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": | |
r""" | |
Get the attention processor in use. | |
Args: | |
return_deprecated_lora (`bool`, *optional*, defaults to `False`): | |
Set to `True` to return the deprecated LoRA attention processor. | |
Returns: | |
"AttentionProcessor": The attention processor in use. | |
""" | |
if not return_deprecated_lora: | |
return self.processor | |
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible | |
# serialization format for LoRA Attention Processors. It should be deleted once the integration | |
# with PEFT is completed. | |
is_lora_activated = { | |
name: module.lora_layer is not None | |
for name, module in self.named_modules() | |
if hasattr(module, "lora_layer") | |
} | |
# 1. if no layer has a LoRA activated we can return the processor as usual | |
if not any(is_lora_activated.values()): | |
return self.processor | |
# LoRA is not applied to `add_k_proj` or `add_v_proj`, so drop them from the check | |
is_lora_activated.pop("add_k_proj", None) | |
is_lora_activated.pop("add_v_proj", None) | |
# 2. else it is not possible that only some layers have LoRA activated | |
if not all(is_lora_activated.values()): | |
raise ValueError( | |
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" | |
) | |
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor | |
non_lora_processor_cls_name = self.processor.__class__.__name__ | |
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) | |
hidden_size = self.inner_dim | |
# now create a LoRA attention processor from the LoRA layers | |
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: | |
kwargs = { | |
"cross_attention_dim": self.cross_attention_dim, | |
"rank": self.to_q.lora_layer.rank, | |
"network_alpha": self.to_q.lora_layer.network_alpha, | |
"q_rank": self.to_q.lora_layer.rank, | |
"q_hidden_size": self.to_q.lora_layer.out_features, | |
"k_rank": self.to_k.lora_layer.rank, | |
"k_hidden_size": self.to_k.lora_layer.out_features, | |
"v_rank": self.to_v.lora_layer.rank, | |
"v_hidden_size": self.to_v.lora_layer.out_features, | |
"out_rank": self.to_out[0].lora_layer.rank, | |
"out_hidden_size": self.to_out[0].lora_layer.out_features, | |
} | |
if hasattr(self.processor, "attention_op"): | |
kwargs["attention_op"] = self.processor.attention_op | |
lora_processor = lora_processor_cls(hidden_size, **kwargs) | |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) | |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) | |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) | |
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) | |
elif lora_processor_cls == LoRAAttnAddedKVProcessor: | |
lora_processor = lora_processor_cls( | |
hidden_size, | |
cross_attention_dim=self.add_k_proj.weight.shape[0], | |
rank=self.to_q.lora_layer.rank, | |
network_alpha=self.to_q.lora_layer.network_alpha, | |
) | |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) | |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) | |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) | |
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) | |
# only save if used | |
if self.add_k_proj.lora_layer is not None: | |
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) | |
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) | |
else: | |
lora_processor.add_k_proj_lora = None | |
lora_processor.add_v_proj_lora = None | |
else: | |
raise ValueError(f"{lora_processor_cls} does not exist.") | |
return lora_processor | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
**cross_attention_kwargs, | |
) -> torch.Tensor: | |
r""" | |
The forward method of the `Attention` class. | |
Args: | |
hidden_states (`torch.Tensor`): | |
The hidden states of the query. | |
encoder_hidden_states (`torch.Tensor`, *optional*): | |
The hidden states of the encoder. | |
attention_mask (`torch.Tensor`, *optional*): | |
The attention mask to use. If `None`, no mask is applied. | |
**cross_attention_kwargs: | |
Additional keyword arguments to pass along to the cross attention. | |
Returns: | |
`torch.Tensor`: The output of the attention layer. | |
""" | |
# The `Attention` class can call different attention processors / attention functions | |
# here we simply pass along all tensors to the selected processor class | |
# For standard processors that are defined here, `**cross_attention_kwargs` is empty | |
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) | |
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] | |
if len(unused_kwargs) > 0: | |
logger.warning( | |
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." | |
) | |
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} | |
return self.processor( | |
self, | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
**cross_attention_kwargs, | |
) | |
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: | |
r""" | |
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` | |
is the number of heads initialized while constructing the `Attention` class. | |
Args: | |
tensor (`torch.Tensor`): The tensor to reshape. | |
Returns: | |
`torch.Tensor`: The reshaped tensor. | |
""" | |
head_size = self.heads | |
batch_size, seq_len, dim = tensor.shape | |
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) | |
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) | |
return tensor | |
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: | |
r""" | |
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is | |
the number of heads initialized while constructing the `Attention` class. | |
Args: | |
tensor (`torch.Tensor`): The tensor to reshape. | |
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is | |
reshaped to `[batch_size * heads, seq_len, dim // heads]`. | |
Returns: | |
`torch.Tensor`: The reshaped tensor. | |
""" | |
head_size = self.heads | |
if tensor.ndim == 3: | |
batch_size, seq_len, dim = tensor.shape | |
extra_dim = 1 | |
else: | |
batch_size, extra_dim, seq_len, dim = tensor.shape | |
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) | |
tensor = tensor.permute(0, 2, 1, 3) | |
if out_dim == 3: | |
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) | |
return tensor | |
def get_attention_scores( | |
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None | |
) -> torch.Tensor: | |
r""" | |
Compute the attention scores. | |
Args: | |
query (`torch.Tensor`): The query tensor. | |
key (`torch.Tensor`): The key tensor. | |
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. | |
Returns: | |
`torch.Tensor`: The attention probabilities/scores. | |
""" | |
dtype = query.dtype | |
if self.upcast_attention: | |
query = query.float() | |
key = key.float() | |
if attention_mask is None: | |
baddbmm_input = torch.empty( | |
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device | |
) | |
beta = 0 | |
else: | |
baddbmm_input = attention_mask | |
beta = 1 | |
attention_scores = torch.baddbmm( | |
baddbmm_input, | |
query, | |
key.transpose(-1, -2), | |
beta=beta, | |
alpha=self.scale, | |
) | |
del baddbmm_input | |
if self.upcast_softmax: | |
attention_scores = attention_scores.float() | |
attention_probs = attention_scores.softmax(dim=-1) | |
del attention_scores | |
attention_probs = attention_probs.to(dtype) | |
return attention_probs | |
def prepare_attention_mask( | |
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 | |
) -> torch.Tensor: | |
r""" | |
Prepare the attention mask for the attention computation. | |
Args: | |
attention_mask (`torch.Tensor`): | |
The attention mask to prepare. | |
target_length (`int`): | |
The target length of the attention mask. This is the length of the attention mask after padding. | |
batch_size (`int`): | |
The batch size, which is used to repeat the attention mask. | |
out_dim (`int`, *optional*, defaults to `3`): | |
The output dimension of the attention mask. Can be either `3` or `4`. | |
Returns: | |
`torch.Tensor`: The prepared attention mask. | |
""" | |
head_size = self.heads | |
if attention_mask is None: | |
return attention_mask | |
current_length: int = attention_mask.shape[-1] | |
if current_length != target_length: | |
if attention_mask.device.type == "mps": | |
# HACK: MPS: Does not support padding by greater than dimension of input tensor. | |
# Instead, we can manually construct the padding tensor. | |
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) | |
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) | |
attention_mask = torch.cat([attention_mask, padding], dim=2) | |
else: | |
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask: | |
# we want to instead pad by (0, remaining_length), where remaining_length is: | |
# remaining_length: int = target_length - current_length | |
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding | |
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) | |
if out_dim == 3: | |
if attention_mask.shape[0] < batch_size * head_size: | |
attention_mask = attention_mask.repeat_interleave(head_size, dim=0) | |
elif out_dim == 4: | |
attention_mask = attention_mask.unsqueeze(1) | |
attention_mask = attention_mask.repeat_interleave(head_size, dim=1) | |
return attention_mask | |
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: | |
r""" | |
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the | |
`Attention` class. | |
Args: | |
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. | |
Returns: | |
`torch.Tensor`: The normalized encoder hidden states. | |
""" | |
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" | |
if isinstance(self.norm_cross, nn.LayerNorm): | |
encoder_hidden_states = self.norm_cross(encoder_hidden_states) | |
elif isinstance(self.norm_cross, nn.GroupNorm): | |
# Group norm norms along the channels dimension and expects | |
# input to be in the shape of (N, C, *). In this case, we want | |
# to norm along the hidden dimension, so we need to move | |
# (batch_size, sequence_length, hidden_size) -> | |
# (batch_size, hidden_size, sequence_length) | |
encoder_hidden_states = encoder_hidden_states.transpose(1, 2) | |
encoder_hidden_states = self.norm_cross(encoder_hidden_states) | |
encoder_hidden_states = encoder_hidden_states.transpose(1, 2) | |
else: | |
assert False, "norm_cross must be an instance of nn.LayerNorm or nn.GroupNorm" | |
return encoder_hidden_states | |
@torch.no_grad() | |
def fuse_projections(self, fuse=True): | |
device = self.to_q.weight.data.device | |
dtype = self.to_q.weight.data.dtype | |
if not self.is_cross_attention: | |
# fetch weight matrices. | |
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) | |
in_features = concatenated_weights.shape[1] | |
out_features = concatenated_weights.shape[0] | |
# create a new single projection layer and copy over the weights. | |
self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) | |
self.to_qkv.weight.copy_(concatenated_weights) | |
if self.use_bias: | |
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) | |
self.to_qkv.bias.copy_(concatenated_bias) | |
else: | |
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) | |
in_features = concatenated_weights.shape[1] | |
out_features = concatenated_weights.shape[0] | |
self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) | |
self.to_kv.weight.copy_(concatenated_weights) | |
if self.use_bias: | |
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) | |
self.to_kv.bias.copy_(concatenated_bias) | |
self.fused_projections = fuse | |
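# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# The `head_to_batch_dim` / `batch_to_head_dim` helpers above fold the head dimension
# into the batch dimension and back again. A minimal stand-alone reproduction of that
# round trip, using only the `torch` import this module already has; the shapes are
# arbitrary and chosen purely for illustration.
def _sketch_head_batch_roundtrip(batch_size: int = 2, seq_len: int = 5, dim: int = 12, heads: int = 4) -> None:
    x = torch.randn(batch_size, seq_len, dim)
    # head_to_batch_dim: [B, S, D] -> [B * H, S, D // H]
    y = x.reshape(batch_size, seq_len, heads, dim // heads)
    y = y.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)
    # batch_to_head_dim: [B * H, S, D // H] -> [B, S, D]
    z = y.reshape(batch_size, heads, seq_len, dim // heads)
    z = z.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
    # Only the layout changes; the round trip is lossless.
    assert torch.equal(x, z)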
class AttnProcessor: | |
r""" | |
Default processor for performing attention-related computations. | |
""" | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
) -> torch.Tensor: | |
residual = hidden_states | |
args = () if USE_PEFT_BACKEND else (scale,) | |
if attn.spatial_norm is not None: | |
hidden_states = attn.spatial_norm(hidden_states, temb) | |
input_ndim = hidden_states.ndim | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states, *args) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
key = attn.to_k(encoder_hidden_states, *args) | |
value = attn.to_v(encoder_hidden_states, *args) | |
query = attn.head_to_batch_dim(query) | |
key = attn.head_to_batch_dim(key) | |
value = attn.head_to_batch_dim(value) | |
attention_probs = attn.get_attention_scores(query, key, attention_mask) | |
hidden_states = torch.bmm(attention_probs, value) | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states, *args) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
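# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# `AttnProcessor` scores queries against keys via `Attention.get_attention_scores`,
# which folds the additive mask and the scale into a single `torch.baddbmm` call
# (beta=1 when a mask is present). The sketch below checks that kernel against the
# explicit softmax(scale * QK^T + mask) @ V formula; shapes and the -10000 bias are
# illustrative values only.
def _sketch_baddbmm_attention(scale: float = 0.125) -> None:
    q = torch.randn(6, 4, 8)   # [batch * heads, query_tokens, head_dim]
    k = torch.randn(6, 5, 8)   # [batch * heads, key_tokens, head_dim]
    v = torch.randn(6, 5, 8)
    mask = torch.zeros(6, 4, 5)
    mask[:, :, -1] = -10000.0  # additive bias that effectively masks the last key token
    scores = torch.baddbmm(mask, q, k.transpose(-1, -2), beta=1, alpha=scale)
    probs = scores.softmax(dim=-1)
    out = torch.bmm(probs, v)
    ref = ((q @ k.transpose(-1, -2)) * scale + mask).softmax(dim=-1) @ v
    assert torch.allclose(out, ref, atol=1e-5)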
class CustomDiffusionAttnProcessor(nn.Module): | |
r""" | |
Processor for implementing attention for the Custom Diffusion method. | |
Args: | |
train_kv (`bool`, defaults to `True`): | |
Whether to newly train the key and value matrices corresponding to the text features. | |
train_q_out (`bool`, defaults to `True`): | |
Whether to newly train query matrices corresponding to the latent image features. | |
hidden_size (`int`, *optional*, defaults to `None`): | |
The hidden size of the attention layer. | |
cross_attention_dim (`int`, *optional*, defaults to `None`): | |
The number of channels in the `encoder_hidden_states`. | |
out_bias (`bool`, defaults to `True`): | |
Whether to include the bias parameter in `train_q_out`. | |
dropout (`float`, *optional*, defaults to 0.0): | |
The dropout probability to use. | |
""" | |
def __init__( | |
self, | |
train_kv: bool = True, | |
train_q_out: bool = True, | |
hidden_size: Optional[int] = None, | |
cross_attention_dim: Optional[int] = None, | |
out_bias: bool = True, | |
dropout: float = 0.0, | |
): | |
super().__init__() | |
self.train_kv = train_kv | |
self.train_q_out = train_q_out | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
# `_custom_diffusion` id for easy serialization and loading. | |
if self.train_kv: | |
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) | |
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) | |
if self.train_q_out: | |
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) | |
self.to_out_custom_diffusion = nn.ModuleList([]) | |
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) | |
self.to_out_custom_diffusion.append(nn.Dropout(dropout)) | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
) -> torch.Tensor: | |
batch_size, sequence_length, _ = hidden_states.shape | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if self.train_q_out: | |
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) | |
else: | |
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) | |
if encoder_hidden_states is None: | |
crossattn = False | |
encoder_hidden_states = hidden_states | |
else: | |
crossattn = True | |
if attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
if self.train_kv: | |
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) | |
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) | |
key = key.to(attn.to_q.weight.dtype) | |
value = value.to(attn.to_q.weight.dtype) | |
else: | |
key = attn.to_k(encoder_hidden_states) | |
value = attn.to_v(encoder_hidden_states) | |
if crossattn: | |
detach = torch.ones_like(key) | |
detach[:, :1, :] = detach[:, :1, :] * 0.0 | |
key = detach * key + (1 - detach) * key.detach() | |
value = detach * value + (1 - detach) * value.detach() | |
query = attn.head_to_batch_dim(query) | |
key = attn.head_to_batch_dim(key) | |
value = attn.head_to_batch_dim(value) | |
attention_probs = attn.get_attention_scores(query, key, attention_mask) | |
hidden_states = torch.bmm(attention_probs, value) | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
if self.train_q_out: | |
# linear proj | |
hidden_states = self.to_out_custom_diffusion[0](hidden_states) | |
# dropout | |
hidden_states = self.to_out_custom_diffusion[1](hidden_states) | |
else: | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
return hidden_states | |
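# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# In the cross-attention branch above, Custom Diffusion builds a `detach` mask that is
# zero for the first text token and one elsewhere, so gradients only reach the keys and
# values of the remaining tokens. A tiny gradient check of that trick (tensor sizes are
# arbitrary):
def _sketch_custom_diffusion_detach() -> None:
    key = torch.randn(1, 4, 8, requires_grad=True)
    detach = torch.ones_like(key)
    detach[:, :1, :] = detach[:, :1, :] * 0.0
    mixed = detach * key + (1 - detach) * key.detach()
    mixed.sum().backward()
    # The first token is frozen; every other token stays trainable.
    assert torch.all(key.grad[:, 0] == 0)
    assert torch.all(key.grad[:, 1:] == 1)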
class AttnAddedKVProcessor: | |
r""" | |
Processor for performing attention-related computations with extra learnable key and value matrices for the text | |
encoder. | |
""" | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
) -> torch.Tensor: | |
residual = hidden_states | |
args = () if USE_PEFT_BACKEND else (scale,) | |
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) | |
batch_size, sequence_length, _ = hidden_states.shape | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states, *args) | |
query = attn.head_to_batch_dim(query) | |
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args) | |
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args) | |
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) | |
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) | |
if not attn.only_cross_attention: | |
key = attn.to_k(hidden_states, *args) | |
value = attn.to_v(hidden_states, *args) | |
key = attn.head_to_batch_dim(key) | |
value = attn.head_to_batch_dim(value) | |
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) | |
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) | |
else: | |
key = encoder_hidden_states_key_proj | |
value = encoder_hidden_states_value_proj | |
attention_probs = attn.get_attention_scores(query, key, attention_mask) | |
hidden_states = torch.bmm(attention_probs, value) | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states, *args) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) | |
hidden_states = hidden_states + residual | |
return hidden_states | |
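# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# The added-KV processors prepend the learned text-encoder projections to the
# self-attention keys and values along the token axis, so each query attends over the
# text tokens and the image tokens jointly. A shape-only sketch (token counts below are
# just illustrative, e.g. 77 for a CLIP caption):
def _sketch_added_kv_concat() -> None:
    batch_heads, image_tokens, text_tokens, head_dim = 8, 64, 77, 40
    self_key = torch.randn(batch_heads, image_tokens, head_dim)
    text_key = torch.randn(batch_heads, text_tokens, head_dim)
    key = torch.cat([text_key, self_key], dim=1)
    assert key.shape == (batch_heads, text_tokens + image_tokens, head_dim)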
class AttnAddedKVProcessor2_0: | |
r""" | |
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra | |
learnable key and value matrices for the text encoder. | |
""" | |
def __init__(self): | |
if not hasattr(F, "scaled_dot_product_attention"): | |
raise ImportError( | |
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." | |
) | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
) -> torch.Tensor: | |
residual = hidden_states | |
args = () if USE_PEFT_BACKEND else (scale,) | |
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) | |
batch_size, sequence_length, _ = hidden_states.shape | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states, *args) | |
query = attn.head_to_batch_dim(query, out_dim=4) | |
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) | |
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) | |
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4) | |
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) | |
if not attn.only_cross_attention: | |
key = attn.to_k(hidden_states, *args) | |
value = attn.to_v(hidden_states, *args) | |
key = attn.head_to_batch_dim(key, out_dim=4) | |
value = attn.head_to_batch_dim(value, out_dim=4) | |
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) | |
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) | |
else: | |
key = encoder_hidden_states_key_proj | |
value = encoder_hidden_states_value_proj | |
# the output of sdp = (batch, num_heads, seq_len, head_dim) | |
# TODO: add support for attn.scale when we move to Torch 2.1 | |
hidden_states = F.scaled_dot_product_attention( | |
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False | |
) | |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states, *args) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) | |
hidden_states = hidden_states + residual | |
return hidden_states | |
class XFormersAttnAddedKVProcessor: | |
r""" | |
Processor for implementing memory efficient attention using xFormers. | |
Args: | |
attention_op (`Callable`, *optional*, defaults to `None`): | |
The base | |
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to | |
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best | |
operator. | |
""" | |
def __init__(self, attention_op: Optional[Callable] = None): | |
self.attention_op = attention_op | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
) -> torch.Tensor: | |
residual = hidden_states | |
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) | |
batch_size, sequence_length, _ = hidden_states.shape | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states) | |
query = attn.head_to_batch_dim(query) | |
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) | |
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) | |
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) | |
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) | |
if not attn.only_cross_attention: | |
key = attn.to_k(hidden_states) | |
value = attn.to_v(hidden_states) | |
key = attn.head_to_batch_dim(key) | |
value = attn.head_to_batch_dim(value) | |
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) | |
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) | |
else: | |
key = encoder_hidden_states_key_proj | |
value = encoder_hidden_states_value_proj | |
hidden_states = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale | |
) | |
hidden_states = hidden_states.to(query.dtype) | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) | |
hidden_states = hidden_states + residual | |
return hidden_states | |
class XFormersAttnProcessor: | |
r""" | |
Processor for implementing memory efficient attention using xFormers. | |
Args: | |
attention_op (`Callable`, *optional*, defaults to `None`): | |
The base | |
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to | |
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best | |
operator. | |
""" | |
def __init__(self, attention_op: Optional[Callable] = None): | |
self.attention_op = attention_op | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
) -> torch.FloatTensor: | |
residual = hidden_states | |
args = () if USE_PEFT_BACKEND else (scale,) | |
if attn.spatial_norm is not None: | |
hidden_states = attn.spatial_norm(hidden_states, temb) | |
input_ndim = hidden_states.ndim | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, key_tokens, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) | |
if attention_mask is not None: | |
# expand our mask's singleton query_tokens dimension: | |
# [batch*heads, 1, key_tokens] -> | |
# [batch*heads, query_tokens, key_tokens] | |
# so that it can be added as a bias onto the attention scores that xformers computes: | |
# [batch*heads, query_tokens, key_tokens] | |
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us. | |
_, query_tokens, _ = hidden_states.shape | |
attention_mask = attention_mask.expand(-1, query_tokens, -1) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states, *args) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
key = attn.to_k(encoder_hidden_states, *args) | |
value = attn.to_v(encoder_hidden_states, *args) | |
query = attn.head_to_batch_dim(query).contiguous() | |
key = attn.head_to_batch_dim(key).contiguous() | |
value = attn.head_to_batch_dim(value).contiguous() | |
hidden_states = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale | |
) | |
hidden_states = hidden_states.to(query.dtype) | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states, *args) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
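# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# xformers does not broadcast the singleton query dimension of an attention bias, so
# `XFormersAttnProcessor` expands [batch * heads, 1, key_tokens] to
# [batch * heads, query_tokens, key_tokens] before the call. The expansion is a view,
# not a copy; the sketch below only checks that shape logic and does not require
# xformers to be installed.
def _sketch_xformers_mask_expand() -> None:
    batch_heads, query_tokens, key_tokens = 8, 64, 77
    attention_bias = torch.zeros(batch_heads, 1, key_tokens)
    expanded = attention_bias.expand(-1, query_tokens, -1)
    assert expanded.shape == (batch_heads, query_tokens, key_tokens)
    assert expanded.data_ptr() == attention_bias.data_ptr()  # still a view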
class AttnProcessor2_0: | |
r""" | |
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). | |
""" | |
def __init__(self): | |
if not hasattr(F, "scaled_dot_product_attention"): | |
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
) -> torch.FloatTensor: | |
residual = hidden_states | |
if attn.spatial_norm is not None: | |
hidden_states = attn.spatial_norm(hidden_states, temb) | |
input_ndim = hidden_states.ndim | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
if attention_mask is not None: | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
# scaled_dot_product_attention expects attention_mask shape to be | |
# (batch, heads, source_length, target_length) | |
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
args = () if USE_PEFT_BACKEND else (scale,) | |
query = attn.to_q(hidden_states, *args) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
key = attn.to_k(encoder_hidden_states, *args) | |
value = attn.to_v(encoder_hidden_states, *args) | |
inner_dim = key.shape[-1] | |
head_dim = inner_dim // attn.heads | |
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
# the output of sdp = (batch, num_heads, seq_len, head_dim) | |
# TODO: add support for attn.scale when we move to Torch 2.1 | |
hidden_states = F.scaled_dot_product_attention( | |
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False | |
) | |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) | |
hidden_states = hidden_states.to(query.dtype) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states, *args) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
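# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# `AttnProcessor2_0` reshapes the [B, S, H * D] projections into the [B, H, S, D]
# layout that `F.scaled_dot_product_attention` expects, then reshapes back before the
# output projection. A minimal round trip (requires PyTorch 2.0; sizes are arbitrary):
def _sketch_sdpa_reshape(batch_size: int = 2, seq_len: int = 16, heads: int = 4, head_dim: int = 8) -> None:
    inner_dim = heads * head_dim
    query = torch.randn(batch_size, seq_len, inner_dim)
    key = torch.randn(batch_size, seq_len, inner_dim)
    value = torch.randn(batch_size, seq_len, inner_dim)
    q = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    k = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    v = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).reshape(batch_size, -1, inner_dim)
    assert out.shape == query.shape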
class FusedAttnProcessor2_0: | |
r""" | |
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). | |
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query, | |
key, value) are fused. For cross-attention modules, key and value projection matrices are fused. | |
<Tip warning={true}> | |
This API is currently 🧪 experimental in nature and can change in future. | |
</Tip> | |
""" | |
def __init__(self): | |
if not hasattr(F, "scaled_dot_product_attention"): | |
raise ImportError( | |
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0." | |
) | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
) -> torch.FloatTensor: | |
residual = hidden_states | |
if attn.spatial_norm is not None: | |
hidden_states = attn.spatial_norm(hidden_states, temb) | |
input_ndim = hidden_states.ndim | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
if attention_mask is not None: | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
# scaled_dot_product_attention expects attention_mask shape to be | |
# (batch, heads, source_length, target_length) | |
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
args = () if USE_PEFT_BACKEND else (scale,) | |
if encoder_hidden_states is None: | |
qkv = attn.to_qkv(hidden_states, *args) | |
split_size = qkv.shape[-1] // 3 | |
query, key, value = torch.split(qkv, split_size, dim=-1) | |
else: | |
if attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
query = attn.to_q(hidden_states, *args) | |
kv = attn.to_kv(encoder_hidden_states, *args) | |
split_size = kv.shape[-1] // 2 | |
key, value = torch.split(kv, split_size, dim=-1) | |
inner_dim = key.shape[-1] | |
head_dim = inner_dim // attn.heads | |
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
# the output of sdp = (batch, num_heads, seq_len, head_dim) | |
# TODO: add support for attn.scale when we move to Torch 2.1 | |
hidden_states = F.scaled_dot_product_attention( | |
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False | |
) | |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) | |
hidden_states = hidden_states.to(query.dtype) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states, *args) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
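# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# `Attention.fuse_projections` concatenates the separate q/k/v weights into a single
# linear layer, and `FusedAttnProcessor2_0` splits the single matmul back into q, k, v.
# The sketch below mirrors that with plain `nn.Linear` (the real code uses
# `self.linear_cls`, which may differ) and checks the fused result matches the unfused
# projections.
@torch.no_grad()
def _sketch_fused_qkv(dim: int = 32) -> None:
    to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
    to_qkv = nn.Linear(dim, 3 * dim, bias=False)
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    x = torch.randn(2, 7, dim)
    qkv = to_qkv(x)
    q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
    assert torch.allclose(q, to_q(x), atol=1e-6)
    assert torch.allclose(k, to_k(x), atol=1e-6)
    assert torch.allclose(v, to_v(x), atol=1e-6)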
class CustomDiffusionXFormersAttnProcessor(nn.Module): | |
r""" | |
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. | |
Args: | |
train_kv (`bool`, defaults to `True`): | |
Whether to newly train the key and value matrices corresponding to the text features. | |
train_q_out (`bool`, defaults to `True`): | |
Whether to newly train query matrices corresponding to the latent image features. | |
hidden_size (`int`, *optional*, defaults to `None`): | |
The hidden size of the attention layer. | |
cross_attention_dim (`int`, *optional*, defaults to `None`): | |
The number of channels in the `encoder_hidden_states`. | |
out_bias (`bool`, defaults to `True`): | |
Whether to include the bias parameter in `train_q_out`. | |
dropout (`float`, *optional*, defaults to 0.0): | |
The dropout probability to use. | |
attention_op (`Callable`, *optional*, defaults to `None`): | |
The base | |
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use | |
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. | |
""" | |
def __init__( | |
self, | |
train_kv: bool = True, | |
train_q_out: bool = False, | |
hidden_size: Optional[int] = None, | |
cross_attention_dim: Optional[int] = None, | |
out_bias: bool = True, | |
dropout: float = 0.0, | |
attention_op: Optional[Callable] = None, | |
): | |
super().__init__() | |
self.train_kv = train_kv | |
self.train_q_out = train_q_out | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
self.attention_op = attention_op | |
# `_custom_diffusion` id for easy serialization and loading. | |
if self.train_kv: | |
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) | |
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) | |
if self.train_q_out: | |
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) | |
self.to_out_custom_diffusion = nn.ModuleList([]) | |
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) | |
self.to_out_custom_diffusion.append(nn.Dropout(dropout)) | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if self.train_q_out: | |
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) | |
else: | |
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) | |
if encoder_hidden_states is None: | |
crossattn = False | |
encoder_hidden_states = hidden_states | |
else: | |
crossattn = True | |
if attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
if self.train_kv: | |
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) | |
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) | |
key = key.to(attn.to_q.weight.dtype) | |
value = value.to(attn.to_q.weight.dtype) | |
else: | |
key = attn.to_k(encoder_hidden_states) | |
value = attn.to_v(encoder_hidden_states) | |
if crossattn: | |
detach = torch.ones_like(key) | |
detach[:, :1, :] = detach[:, :1, :] * 0.0 | |
key = detach * key + (1 - detach) * key.detach() | |
value = detach * value + (1 - detach) * value.detach() | |
query = attn.head_to_batch_dim(query).contiguous() | |
key = attn.head_to_batch_dim(key).contiguous() | |
value = attn.head_to_batch_dim(value).contiguous() | |
hidden_states = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale | |
) | |
hidden_states = hidden_states.to(query.dtype) | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
if self.train_q_out: | |
# linear proj | |
hidden_states = self.to_out_custom_diffusion[0](hidden_states) | |
# dropout | |
hidden_states = self.to_out_custom_diffusion[1](hidden_states) | |
else: | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
return hidden_states | |
class CustomDiffusionAttnProcessor2_0(nn.Module): | |
r""" | |
Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled | |
dot-product attention. | |
Args: | |
train_kv (`bool`, defaults to `True`): | |
Whether to newly train the key and value matrices corresponding to the text features. | |
train_q_out (`bool`, defaults to `True`): | |
Whether to newly train query matrices corresponding to the latent image features. | |
hidden_size (`int`, *optional*, defaults to `None`): | |
The hidden size of the attention layer. | |
cross_attention_dim (`int`, *optional*, defaults to `None`): | |
The number of channels in the `encoder_hidden_states`. | |
out_bias (`bool`, defaults to `True`): | |
Whether to include the bias parameter in `train_q_out`. | |
dropout (`float`, *optional*, defaults to 0.0): | |
The dropout probability to use. | |
""" | |
def __init__( | |
self, | |
train_kv: bool = True, | |
train_q_out: bool = True, | |
hidden_size: Optional[int] = None, | |
cross_attention_dim: Optional[int] = None, | |
out_bias: bool = True, | |
dropout: float = 0.0, | |
): | |
super().__init__() | |
self.train_kv = train_kv | |
self.train_q_out = train_q_out | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
# `_custom_diffusion` id for easy serialization and loading. | |
if self.train_kv: | |
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) | |
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) | |
if self.train_q_out: | |
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) | |
self.to_out_custom_diffusion = nn.ModuleList([]) | |
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) | |
self.to_out_custom_diffusion.append(nn.Dropout(dropout)) | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
batch_size, sequence_length, _ = hidden_states.shape | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if self.train_q_out: | |
query = self.to_q_custom_diffusion(hidden_states) | |
else: | |
query = attn.to_q(hidden_states) | |
if encoder_hidden_states is None: | |
crossattn = False | |
encoder_hidden_states = hidden_states | |
else: | |
crossattn = True | |
if attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
if self.train_kv: | |
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) | |
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) | |
key = key.to(attn.to_q.weight.dtype) | |
value = value.to(attn.to_q.weight.dtype) | |
else: | |
key = attn.to_k(encoder_hidden_states) | |
value = attn.to_v(encoder_hidden_states) | |
if crossattn: | |
detach = torch.ones_like(key) | |
detach[:, :1, :] = detach[:, :1, :] * 0.0 | |
key = detach * key + (1 - detach) * key.detach() | |
value = detach * value + (1 - detach) * value.detach() | |
inner_dim = hidden_states.shape[-1] | |
head_dim = inner_dim // attn.heads | |
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
# the output of sdp = (batch, num_heads, seq_len, head_dim) | |
# TODO: add support for attn.scale when we move to Torch 2.1 | |
hidden_states = F.scaled_dot_product_attention( | |
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False | |
) | |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) | |
hidden_states = hidden_states.to(query.dtype) | |
if self.train_q_out: | |
# linear proj | |
hidden_states = self.to_out_custom_diffusion[0](hidden_states) | |
# dropout | |
hidden_states = self.to_out_custom_diffusion[1](hidden_states) | |
else: | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
return hidden_states | |
class SlicedAttnProcessor: | |
r""" | |
Processor for implementing sliced attention. | |
Args: | |
slice_size (`int`, *optional*): | |
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and | |
`attention_head_dim` must be a multiple of the `slice_size`. | |
""" | |
def __init__(self, slice_size: int): | |
self.slice_size = slice_size | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
residual = hidden_states | |
input_ndim = hidden_states.ndim | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states) | |
dim = query.shape[-1] | |
query = attn.head_to_batch_dim(query) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
key = attn.to_k(encoder_hidden_states) | |
value = attn.to_v(encoder_hidden_states) | |
key = attn.head_to_batch_dim(key) | |
value = attn.head_to_batch_dim(value) | |
batch_size_attention, query_tokens, _ = query.shape | |
hidden_states = torch.zeros( | |
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype | |
) | |
for i in range(batch_size_attention // self.slice_size): | |
start_idx = i * self.slice_size | |
end_idx = (i + 1) * self.slice_size | |
query_slice = query[start_idx:end_idx] | |
key_slice = key[start_idx:end_idx] | |
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None | |
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) | |
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) | |
hidden_states[start_idx:end_idx] = attn_slice | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
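# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# Sliced attention trades a little speed for memory: the batch * heads dimension is
# processed `slice_size` rows at a time so only a small score matrix is materialised
# per step. The sketch checks that the sliced loop reproduces full attention; like the
# processor above, it assumes the batch dimension is divisible by `slice_size`.
def _sketch_sliced_attention(slice_size: int = 2) -> None:
    q = torch.randn(8, 16, 8)  # [batch * heads, query_tokens, head_dim]
    k = torch.randn(8, 16, 8)
    v = torch.randn(8, 16, 8)
    scale = q.shape[-1] ** -0.5
    full = ((q @ k.transpose(-1, -2)) * scale).softmax(dim=-1) @ v
    sliced = torch.zeros_like(full)
    for i in range(q.shape[0] // slice_size):
        s, e = i * slice_size, (i + 1) * slice_size
        probs = ((q[s:e] @ k[s:e].transpose(-1, -2)) * scale).softmax(dim=-1)
        sliced[s:e] = probs @ v[s:e]
    assert torch.allclose(full, sliced, atol=1e-5)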
class SlicedAttnAddedKVProcessor: | |
r""" | |
Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder. | |
Args: | |
slice_size (`int`, *optional*): | |
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and | |
`attention_head_dim` must be a multiple of the `slice_size`. | |
""" | |
def __init__(self, slice_size): | |
self.slice_size = slice_size | |
def __call__( | |
self, | |
attn: "Attention", | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
residual = hidden_states | |
if attn.spatial_norm is not None: | |
hidden_states = attn.spatial_norm(hidden_states, temb) | |
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) | |
batch_size, sequence_length, _ = hidden_states.shape | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states) | |
dim = query.shape[-1] | |
query = attn.head_to_batch_dim(query) | |
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) | |
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) | |
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) | |
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) | |
if not attn.only_cross_attention: | |
key = attn.to_k(hidden_states) | |
value = attn.to_v(hidden_states) | |
key = attn.head_to_batch_dim(key) | |
value = attn.head_to_batch_dim(value) | |
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) | |
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) | |
else: | |
key = encoder_hidden_states_key_proj | |
value = encoder_hidden_states_value_proj | |
batch_size_attention, query_tokens, _ = query.shape | |
hidden_states = torch.zeros( | |
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype | |
) | |
for i in range(batch_size_attention // self.slice_size): | |
start_idx = i * self.slice_size | |
end_idx = (i + 1) * self.slice_size | |
query_slice = query[start_idx:end_idx] | |
key_slice = key[start_idx:end_idx] | |
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None | |
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) | |
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) | |
hidden_states[start_idx:end_idx] = attn_slice | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) | |
hidden_states = hidden_states + residual | |
return hidden_states | |
class SpatialNorm(nn.Module): | |
""" | |
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. | |
Args: | |
f_channels (`int`): | |
The number of channels for input to group normalization layer, and output of the spatial norm layer. | |
zq_channels (`int`): | |
The number of channels for the quantized vector as described in the paper. | |
""" | |
def __init__( | |
self, | |
f_channels: int, | |
zq_channels: int, | |
): | |
super().__init__() | |
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) | |
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) | |
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) | |
def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor: | |
f_size = f.shape[-2:] | |
zq = F.interpolate(zq, size=f_size, mode="nearest") | |
norm_f = self.norm_layer(f) | |
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) | |
return new_f | |
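# --- Editor's sketch (illustrative, not part of diffusers) --------------------------
# `SpatialNorm` group-normalises a decoder feature map and modulates it with a
# quantized latent that is first resized to the feature map's spatial size. A short
# shape check using the class defined above; note `f_channels` has to be divisible by
# the 32 groups hard-coded in `norm_layer`. Channel/spatial sizes are illustrative.
def _sketch_spatial_norm() -> None:
    spatial_norm = SpatialNorm(f_channels=64, zq_channels=4)
    f = torch.randn(1, 64, 32, 32)   # decoder feature map
    zq = torch.randn(1, 4, 8, 8)     # quantized latent, upsampled inside forward()
    out = spatial_norm(f, zq)
    assert out.shape == f.shape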
class LoRAAttnProcessor(nn.Module): | |
def __init__( | |
self, | |
hidden_size: int, | |
cross_attention_dim: Optional[int] = None, | |
rank: int = 4, | |
network_alpha: Optional[int] = None, | |
**kwargs, | |
): | |
deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`." | |
deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False) | |
super().__init__() | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
self.rank = rank | |
q_rank = kwargs.pop("q_rank", None) | |
q_hidden_size = kwargs.pop("q_hidden_size", None) | |
q_rank = q_rank if q_rank is not None else rank | |
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size | |
v_rank = kwargs.pop("v_rank", None) | |
v_hidden_size = kwargs.pop("v_hidden_size", None) | |
v_rank = v_rank if v_rank is not None else rank | |
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size | |
out_rank = kwargs.pop("out_rank", None) | |
out_hidden_size = kwargs.pop("out_hidden_size", None) | |
out_rank = out_rank if out_rank is not None else rank | |
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size | |
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) | |
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) | |
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) | |
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) | |
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: | |
self_cls_name = self.__class__.__name__ | |
deprecate( | |
self_cls_name, | |
"0.26.0", | |
( | |
f"Make sure use {self_cls_name[4:]} instead by setting" | |
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using" | |
" `LoraLoaderMixin.load_lora_weights`" | |
), | |
) | |
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device) | |
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device) | |
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device) | |
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device) | |
attn._modules.pop("processor") | |
attn.processor = AttnProcessor() | |
return attn.processor(attn, hidden_states, *args, **kwargs) | |
class LoRAAttnProcessor2_0(nn.Module): | |
def __init__( | |
self, | |
hidden_size: int, | |
cross_attention_dim: Optional[int] = None, | |
rank: int = 4, | |
network_alpha: Optional[int] = None, | |
**kwargs, | |
): | |
deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`." | |
deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False) | |
super().__init__() | |
if not hasattr(F, "scaled_dot_product_attention"): | |
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
self.rank = rank | |
q_rank = kwargs.pop("q_rank", None) | |
q_hidden_size = kwargs.pop("q_hidden_size", None) | |
q_rank = q_rank if q_rank is not None else rank | |
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size | |
v_rank = kwargs.pop("v_rank", None) | |
v_hidden_size = kwargs.pop("v_hidden_size", None) | |
v_rank = v_rank if v_rank is not None else rank | |
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size | |
out_rank = kwargs.pop("out_rank", None) | |
out_hidden_size = kwargs.pop("out_hidden_size", None) | |
out_rank = out_rank if out_rank is not None else rank | |
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size | |
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) | |
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) | |
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) | |
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) | |
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: | |
self_cls_name = self.__class__.__name__ | |
deprecate( | |
self_cls_name, | |
"0.26.0", | |
( | |
f"Make sure use {self_cls_name[4:]} instead by setting" | |
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using" | |
" `LoraLoaderMixin.load_lora_weights`" | |
), | |
) | |
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device) | |
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device) | |
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device) | |
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device) | |
attn._modules.pop("processor") | |
attn.processor = AttnProcessor2_0() | |
return attn.processor(attn, hidden_states, *args, **kwargs) | |
class LoRAXFormersAttnProcessor(nn.Module): | |
r""" | |
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers. | |
Args: | |
hidden_size (`int`, *optional*): | |
The hidden size of the attention layer. | |
cross_attention_dim (`int`, *optional*): | |
The number of channels in the `encoder_hidden_states`. | |
rank (`int`, defaults to 4): | |
The dimension of the LoRA update matrices. | |
attention_op (`Callable`, *optional*, defaults to `None`): | |
The base | |
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to | |
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best | |
operator. | |
network_alpha (`int`, *optional*): | |
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`): | |
Additional keyword arguments to pass to the `LoRALinearLayer` layers. | |
""" | |
def __init__( | |
self, | |
hidden_size: int, | |
cross_attention_dim: int, | |
rank: int = 4, | |
attention_op: Optional[Callable] = None, | |
network_alpha: Optional[int] = None, | |
**kwargs, | |
): | |
super().__init__() | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
self.rank = rank | |
self.attention_op = attention_op | |
q_rank = kwargs.pop("q_rank", None) | |
q_hidden_size = kwargs.pop("q_hidden_size", None) | |
q_rank = q_rank if q_rank is not None else rank | |
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size | |
v_rank = kwargs.pop("v_rank", None) | |
v_hidden_size = kwargs.pop("v_hidden_size", None) | |
v_rank = v_rank if v_rank is not None else rank | |
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size | |
out_rank = kwargs.pop("out_rank", None) | |
out_hidden_size = kwargs.pop("out_hidden_size", None) | |
out_rank = out_rank if out_rank is not None else rank | |
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size | |
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) | |
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) | |
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) | |
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) | |
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: | |
self_cls_name = self.__class__.__name__ | |
deprecate( | |
self_cls_name, | |
"0.26.0", | |
( | |
f"Make sure use {self_cls_name[4:]} instead by setting" | |
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using" | |
" `LoraLoaderMixin.load_lora_weights`" | |
), | |
) | |
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device) | |
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device) | |
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device) | |
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device) | |
attn._modules.pop("processor") | |
attn.processor = XFormersAttnProcessor() | |
return attn.processor(attn, hidden_states, *args, **kwargs) | |
class LoRAAttnAddedKVProcessor(nn.Module): | |
r""" | |
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text | |
encoder. | |
Args: | |
hidden_size (`int`, *optional*): | |
The hidden size of the attention layer. | |
cross_attention_dim (`int`, *optional*, defaults to `None`): | |
The number of channels in the `encoder_hidden_states`. | |
rank (`int`, defaults to 4): | |
The dimension of the LoRA update matrices. | |
network_alpha (`int`, *optional*): | |
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`): | |
Additional keyword arguments to pass to the `LoRALinearLayer` layers. | |
""" | |
def __init__( | |
self, | |
hidden_size: int, | |
cross_attention_dim: Optional[int] = None, | |
rank: int = 4, | |
network_alpha: Optional[int] = None, | |
): | |
super().__init__() | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
self.rank = rank | |
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) | |
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) | |
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) | |
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) | |
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) | |
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) | |
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: | |
self_cls_name = self.__class__.__name__ | |
deprecate( | |
self_cls_name, | |
"0.26.0", | |
( | |
f"Make sure use {self_cls_name[4:]} instead by setting" | |
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using" | |
" `LoraLoaderMixin.load_lora_weights`" | |
), | |
) | |
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device) | |
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device) | |
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device) | |
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device) | |
attn._modules.pop("processor") | |
attn.processor = AttnAddedKVProcessor() | |
return attn.processor(attn, hidden_states, *args, **kwargs) | |
class IPAdapterAttnProcessor(nn.Module): | |
r""" | |
    Attention processor for multiple IP-Adapters.
Args: | |
hidden_size (`int`): | |
The hidden size of the attention layer. | |
cross_attention_dim (`int`): | |
The number of channels in the `encoder_hidden_states`. | |
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): | |
The context length of the image features. | |
        scale (`float` or `List[float]`, defaults to 1.0):
            The weight scale of the image prompt.
""" | |
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): | |
super().__init__() | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
if not isinstance(num_tokens, (tuple, list)): | |
num_tokens = [num_tokens] | |
self.num_tokens = num_tokens | |
if not isinstance(scale, list): | |
scale = [scale] * len(num_tokens) | |
if len(scale) != len(num_tokens): | |
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") | |
self.scale = scale | |
self.to_k_ip = nn.ModuleList( | |
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] | |
) | |
self.to_v_ip = nn.ModuleList( | |
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] | |
) | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
ip_adapter_masks: Optional[torch.FloatTensor] = None, | |
): | |
residual = hidden_states | |
# separate ip_hidden_states from encoder_hidden_states | |
if encoder_hidden_states is not None: | |
if isinstance(encoder_hidden_states, tuple): | |
encoder_hidden_states, ip_hidden_states = encoder_hidden_states | |
else: | |
deprecation_message = ( | |
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release." | |
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning." | |
) | |
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) | |
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] | |
encoder_hidden_states, ip_hidden_states = ( | |
encoder_hidden_states[:, :end_pos, :], | |
[encoder_hidden_states[:, end_pos:, :]], | |
) | |
if attn.spatial_norm is not None: | |
hidden_states = attn.spatial_norm(hidden_states, temb) | |
input_ndim = hidden_states.ndim | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
key = attn.to_k(encoder_hidden_states) | |
value = attn.to_v(encoder_hidden_states) | |
query = attn.head_to_batch_dim(query) | |
key = attn.head_to_batch_dim(key) | |
value = attn.head_to_batch_dim(value) | |
attention_probs = attn.get_attention_scores(query, key, attention_mask) | |
hidden_states = torch.bmm(attention_probs, value) | |
hidden_states = attn.batch_to_head_dim(hidden_states) | |
if ip_adapter_masks is not None: | |
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: | |
raise ValueError( | |
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." | |
" Please use `IPAdapterMaskProcessor` to preprocess your mask" | |
) | |
if len(ip_adapter_masks) != len(self.scale): | |
raise ValueError( | |
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" | |
) | |
else: | |
ip_adapter_masks = [None] * len(self.scale) | |
# for ip-adapter | |
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( | |
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks | |
): | |
ip_key = to_k_ip(current_ip_hidden_states) | |
ip_value = to_v_ip(current_ip_hidden_states) | |
ip_key = attn.head_to_batch_dim(ip_key) | |
ip_value = attn.head_to_batch_dim(ip_value) | |
ip_attention_probs = attn.get_attention_scores(query, ip_key, None) | |
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) | |
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) | |
if mask is not None: | |
mask_downsample = IPAdapterMaskProcessor.downsample( | |
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] | |
) | |
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) | |
current_ip_hidden_states = current_ip_hidden_states * mask_downsample | |
hidden_states = hidden_states + scale * current_ip_hidden_states | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
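# Illustrative usage sketch (added for clarity; not part of the library). Constructing
# the multi-IP-Adapter processor above without wiring it into an `Attention` module:
# with two image-prompt streams, one (to_k_ip, to_v_ip) projection pair and one scale
# are created per stream. All sizes here are assumed demo values.
def _demo_ip_adapter_attn_processor():
    proc = IPAdapterAttnProcessor(hidden_size=320, cross_attention_dim=768, num_tokens=[4, 4], scale=[1.0, 0.5])
    assert len(proc.to_k_ip) == len(proc.to_v_ip) == 2  # one projection pair per IP-Adapter
    assert proc.scale == [1.0, 0.5]
    return proc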
class IPAdapterAttnProcessor2_0(torch.nn.Module): | |
r""" | |
    Attention processor for IP-Adapter for PyTorch 2.0.
Args: | |
hidden_size (`int`): | |
The hidden size of the attention layer. | |
cross_attention_dim (`int`): | |
The number of channels in the `encoder_hidden_states`. | |
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): | |
The context length of the image features. | |
        scale (`float` or `List[float]`, defaults to 1.0):
            The weight scale of the image prompt.
""" | |
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): | |
super().__init__() | |
if not hasattr(F, "scaled_dot_product_attention"): | |
raise ImportError( | |
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." | |
) | |
self.hidden_size = hidden_size | |
self.cross_attention_dim = cross_attention_dim | |
if not isinstance(num_tokens, (tuple, list)): | |
num_tokens = [num_tokens] | |
self.num_tokens = num_tokens | |
if not isinstance(scale, list): | |
scale = [scale] * len(num_tokens) | |
if len(scale) != len(num_tokens): | |
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") | |
self.scale = scale | |
self.to_k_ip = nn.ModuleList( | |
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] | |
) | |
self.to_v_ip = nn.ModuleList( | |
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] | |
) | |
def __call__( | |
self, | |
attn: Attention, | |
hidden_states: torch.FloatTensor, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
ip_adapter_masks: Optional[torch.FloatTensor] = None, | |
): | |
residual = hidden_states | |
# separate ip_hidden_states from encoder_hidden_states | |
if encoder_hidden_states is not None: | |
if isinstance(encoder_hidden_states, tuple): | |
encoder_hidden_states, ip_hidden_states = encoder_hidden_states | |
else: | |
deprecation_message = ( | |
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release." | |
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning." | |
) | |
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) | |
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] | |
encoder_hidden_states, ip_hidden_states = ( | |
encoder_hidden_states[:, :end_pos, :], | |
[encoder_hidden_states[:, end_pos:, :]], | |
) | |
if attn.spatial_norm is not None: | |
hidden_states = attn.spatial_norm(hidden_states, temb) | |
input_ndim = hidden_states.ndim | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
if attention_mask is not None: | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
# scaled_dot_product_attention expects attention_mask shape to be | |
# (batch, heads, source_length, target_length) | |
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
key = attn.to_k(encoder_hidden_states) | |
value = attn.to_v(encoder_hidden_states) | |
inner_dim = key.shape[-1] | |
head_dim = inner_dim // attn.heads | |
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
# the output of sdp = (batch, num_heads, seq_len, head_dim) | |
# TODO: add support for attn.scale when we move to Torch 2.1 | |
hidden_states = F.scaled_dot_product_attention( | |
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False | |
) | |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) | |
hidden_states = hidden_states.to(query.dtype) | |
if ip_adapter_masks is not None: | |
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: | |
raise ValueError( | |
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." | |
" Please use `IPAdapterMaskProcessor` to preprocess your mask" | |
) | |
if len(ip_adapter_masks) != len(self.scale): | |
raise ValueError( | |
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" | |
) | |
else: | |
ip_adapter_masks = [None] * len(self.scale) | |
# for ip-adapter | |
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( | |
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks | |
): | |
ip_key = to_k_ip(current_ip_hidden_states) | |
ip_value = to_v_ip(current_ip_hidden_states) | |
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | |
# the output of sdp = (batch, num_heads, seq_len, head_dim) | |
# TODO: add support for attn.scale when we move to Torch 2.1 | |
current_ip_hidden_states = F.scaled_dot_product_attention( | |
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False | |
) | |
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( | |
batch_size, -1, attn.heads * head_dim | |
) | |
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) | |
if mask is not None: | |
mask_downsample = IPAdapterMaskProcessor.downsample( | |
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] | |
) | |
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) | |
current_ip_hidden_states = current_ip_hidden_states * mask_downsample | |
hidden_states = hidden_states + scale * current_ip_hidden_states | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
LORA_ATTENTION_PROCESSORS = ( | |
LoRAAttnProcessor, | |
LoRAAttnProcessor2_0, | |
LoRAXFormersAttnProcessor, | |
LoRAAttnAddedKVProcessor, | |
) | |
ADDED_KV_ATTENTION_PROCESSORS = ( | |
AttnAddedKVProcessor, | |
SlicedAttnAddedKVProcessor, | |
AttnAddedKVProcessor2_0, | |
XFormersAttnAddedKVProcessor, | |
LoRAAttnAddedKVProcessor, | |
) | |
CROSS_ATTENTION_PROCESSORS = ( | |
AttnProcessor, | |
AttnProcessor2_0, | |
XFormersAttnProcessor, | |
SlicedAttnProcessor, | |
LoRAAttnProcessor, | |
LoRAAttnProcessor2_0, | |
LoRAXFormersAttnProcessor, | |
IPAdapterAttnProcessor, | |
IPAdapterAttnProcessor2_0, | |
) | |
AttentionProcessor = Union[ | |
AttnProcessor, | |
AttnProcessor2_0, | |
FusedAttnProcessor2_0, | |
XFormersAttnProcessor, | |
SlicedAttnProcessor, | |
AttnAddedKVProcessor, | |
SlicedAttnAddedKVProcessor, | |
AttnAddedKVProcessor2_0, | |
XFormersAttnAddedKVProcessor, | |
CustomDiffusionAttnProcessor, | |
CustomDiffusionXFormersAttnProcessor, | |
CustomDiffusionAttnProcessor2_0, | |
# deprecated | |
LoRAAttnProcessor, | |
LoRAAttnProcessor2_0, | |
LoRAXFormersAttnProcessor, | |
LoRAAttnAddedKVProcessor, | |
] | |
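# Illustrative usage sketch (added for clarity; not part of the library). The tuples
# above are typically used for isinstance checks when deciding which default processor
# a model should fall back to; a hedged helper:
def _uses_added_kv(processor) -> bool:
    # True for processors that carry extra learnable key/value projections.
    return isinstance(processor, ADDED_KV_ATTENTION_PROCESSORS)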
# coding=utf-8 | |
# Copyright 2024 HuggingFace Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import numbers | |
from typing import Dict, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..utils import is_torch_version | |
from .activations import get_activation | |
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings | |
class AdaLayerNorm(nn.Module): | |
r""" | |
Norm layer modified to incorporate timestep embeddings. | |
Parameters: | |
embedding_dim (`int`): The size of each embedding vector. | |
num_embeddings (`int`): The size of the embeddings dictionary. | |
""" | |
def __init__(self, embedding_dim: int, num_embeddings: int): | |
super().__init__() | |
self.emb = nn.Embedding(num_embeddings, embedding_dim) | |
self.silu = nn.SiLU() | |
self.linear = nn.Linear(embedding_dim, embedding_dim * 2) | |
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) | |
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: | |
emb = self.linear(self.silu(self.emb(timestep))) | |
scale, shift = torch.chunk(emb, 2) | |
x = self.norm(x) * (1 + scale) + shift | |
return x | |
class AdaLayerNormZero(nn.Module): | |
r""" | |
Norm layer adaptive layer norm zero (adaLN-Zero). | |
Parameters: | |
embedding_dim (`int`): The size of each embedding vector. | |
num_embeddings (`int`): The size of the embeddings dictionary. | |
""" | |
def __init__(self, embedding_dim: int, num_embeddings: int): | |
super().__init__() | |
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) | |
self.silu = nn.SiLU() | |
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) | |
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) | |
def forward( | |
self, | |
x: torch.Tensor, | |
timestep: torch.Tensor, | |
class_labels: torch.LongTensor, | |
hidden_dtype: Optional[torch.dtype] = None, | |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) | |
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) | |
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] | |
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp | |
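# Illustrative usage sketch (added for clarity; not part of the library). adaLN-Zero as
# used in DiT-style blocks: the combined timestep/class embedding is projected to six
# modulation tensors; the first shift/scale pair is applied here and the gates plus the
# MLP shift/scale are returned to the caller. Sizes and the label count are demo assumptions.
def _demo_ada_layer_norm_zero():
    norm = AdaLayerNormZero(embedding_dim=128, num_embeddings=10)
    x = torch.randn(2, 16, 128)  # (batch, tokens, dim)
    timestep = torch.tensor([10, 20])  # one diffusion timestep per sample
    class_labels = torch.tensor([1, 3])  # one class label per sample
    x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, timestep, class_labels)
    assert x_mod.shape == x.shape and gate_msa.shape == (2, 128)
    return x_mod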
class AdaLayerNormSingle(nn.Module): | |
r""" | |
Norm layer adaptive layer norm single (adaLN-single). | |
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). | |
Parameters: | |
embedding_dim (`int`): The size of each embedding vector. | |
use_additional_conditions (`bool`): To use additional conditions for normalization or not. | |
""" | |
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): | |
super().__init__() | |
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( | |
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions | |
) | |
self.silu = nn.SiLU() | |
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) | |
def forward( | |
self, | |
timestep: torch.Tensor, | |
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
batch_size: Optional[int] = None, | |
hidden_dtype: Optional[torch.dtype] = None, | |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
# No modulation happening here. | |
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) | |
return self.linear(self.silu(embedded_timestep)), embedded_timestep | |
class AdaGroupNorm(nn.Module): | |
r""" | |
GroupNorm layer modified to incorporate timestep embeddings. | |
Parameters: | |
embedding_dim (`int`): The size of each embedding vector. | |
        out_dim (`int`): The number of output channels (must match the channel dimension of the input `x`).
num_groups (`int`): The number of groups to separate the channels into. | |
act_fn (`str`, *optional*, defaults to `None`): The activation function to use. | |
eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. | |
""" | |
def __init__( | |
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 | |
): | |
super().__init__() | |
self.num_groups = num_groups | |
self.eps = eps | |
if act_fn is None: | |
self.act = None | |
else: | |
self.act = get_activation(act_fn) | |
self.linear = nn.Linear(embedding_dim, out_dim * 2) | |
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: | |
if self.act: | |
emb = self.act(emb) | |
emb = self.linear(emb) | |
emb = emb[:, :, None, None] | |
scale, shift = emb.chunk(2, dim=1) | |
x = F.group_norm(x, self.num_groups, eps=self.eps) | |
x = x * (1 + scale) + shift | |
return x | |
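# Illustrative usage sketch (added for clarity; not part of the library). AdaGroupNorm
# predicts a per-channel scale and shift from a conditioning embedding and applies them
# after group normalization. Sizes are demo assumptions; `out_dim` must equal the channel
# dimension of `x` and be divisible by `num_groups`.
def _demo_ada_group_norm():
    norm = AdaGroupNorm(embedding_dim=512, out_dim=64, num_groups=32, act_fn="silu")
    x = torch.randn(2, 64, 16, 16)
    emb = torch.randn(2, 512)
    out = norm(x, emb)
    assert out.shape == x.shape
    return out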
class AdaLayerNormContinuous(nn.Module): | |
def __init__( | |
self, | |
embedding_dim: int, | |
conditioning_embedding_dim: int, | |
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters | |
# because the output is immediately scaled and shifted by the projected conditioning embeddings. | |
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. | |
# However, this is how it was implemented in the original code, and it's rather likely you should | |
# set `elementwise_affine` to False. | |
elementwise_affine=True, | |
eps=1e-5, | |
bias=True, | |
norm_type="layer_norm", | |
): | |
super().__init__() | |
self.silu = nn.SiLU() | |
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) | |
if norm_type == "layer_norm": | |
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) | |
elif norm_type == "rms_norm": | |
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) | |
else: | |
raise ValueError(f"unknown norm_type {norm_type}") | |
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: | |
emb = self.linear(self.silu(conditioning_embedding)) | |
scale, shift = torch.chunk(emb, 2, dim=1) | |
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] | |
return x | |
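# Illustrative usage sketch (added for clarity; not part of the library). The continuous
# variant derives a single scale/shift pair from a continuous conditioning embedding rather
# than from discrete timesteps or labels. Sizes are demo assumptions; the demo is wrapped in
# a function because `LayerNorm`/`RMSNorm` are only defined further down in this module.
def _demo_ada_layer_norm_continuous():
    norm = AdaLayerNormContinuous(embedding_dim=128, conditioning_embedding_dim=256)
    x = torch.randn(2, 16, 128)
    cond = torch.randn(2, 256)
    out = norm(x, cond)
    assert out.shape == x.shape
    return out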
if is_torch_version(">=", "2.1.0"): | |
LayerNorm = nn.LayerNorm | |
else: | |
# Has optional bias parameter compared to torch layer norm | |
# TODO: replace with torch layernorm once min required torch version >= 2.1 | |
class LayerNorm(nn.Module): | |
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): | |
super().__init__() | |
self.eps = eps | |
if isinstance(dim, numbers.Integral): | |
dim = (dim,) | |
self.dim = torch.Size(dim) | |
if elementwise_affine: | |
self.weight = nn.Parameter(torch.ones(dim)) | |
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None | |
else: | |
self.weight = None | |
self.bias = None | |
def forward(self, input): | |
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) | |
class RMSNorm(nn.Module): | |
def __init__(self, dim, eps: float, elementwise_affine: bool = True): | |
super().__init__() | |
self.eps = eps | |
if isinstance(dim, numbers.Integral): | |
dim = (dim,) | |
self.dim = torch.Size(dim) | |
if elementwise_affine: | |
self.weight = nn.Parameter(torch.ones(dim)) | |
else: | |
self.weight = None | |
def forward(self, hidden_states): | |
input_dtype = hidden_states.dtype | |
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) | |
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) | |
if self.weight is not None: | |
# convert into half-precision if necessary | |
if self.weight.dtype in [torch.float16, torch.bfloat16]: | |
hidden_states = hidden_states.to(self.weight.dtype) | |
hidden_states = hidden_states * self.weight | |
else: | |
hidden_states = hidden_states.to(input_dtype) | |
return hidden_states | |
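# Illustrative usage sketch (added for clarity; not part of the library). RMSNorm rescales
# by the root mean square of the last dimension (no mean subtraction, unlike LayerNorm);
# the check below compares against a manual computation with the weight at its all-ones
# initialization. Sizes are demo assumptions.
def _demo_rms_norm():
    norm = RMSNorm(dim=64, eps=1e-6)
    x = torch.randn(2, 10, 64)
    out = norm(x)
    expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(out, expected, atol=1e-5)
    return out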
class GlobalResponseNorm(nn.Module): | |
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 | |
def __init__(self, dim): | |
super().__init__() | |
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) | |
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) | |
def forward(self, x): | |
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) | |
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) | |
return self.gamma * (x * nx) + self.beta + x | |
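# Illustrative usage sketch (added for clarity; not part of the library). Global Response
# Norm (ConvNeXt-V2) expects channels-last input (N, H, W, C); since gamma and beta are
# zero-initialized, the layer is an identity mapping at initialization, which the check
# below verifies. Sizes are demo assumptions.
def _demo_global_response_norm():
    grn = GlobalResponseNorm(dim=64)
    x = torch.randn(2, 8, 8, 64)
    out = grn(x)
    assert out.shape == x.shape
    assert torch.allclose(out, x)  # gamma = beta = 0 at init, so the block reduces to x
    return out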
from typing import Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..utils import USE_PEFT_BACKEND | |
from .lora import LoRACompatibleConv | |
from .normalization import RMSNorm | |
from .upsampling import upfirdn2d_native | |
class Downsample1D(nn.Module): | |
"""A 1D downsampling layer with an optional convolution. | |
Parameters: | |
channels (`int`): | |
number of channels in the inputs and outputs. | |
use_conv (`bool`, default `False`): | |
option to use a convolution. | |
out_channels (`int`, optional): | |
number of output channels. Defaults to `channels`. | |
padding (`int`, default `1`): | |
padding for the convolution. | |
name (`str`, default `conv`): | |
name of the downsampling 1D layer. | |
""" | |
def __init__( | |
self, | |
channels: int, | |
use_conv: bool = False, | |
out_channels: Optional[int] = None, | |
padding: int = 1, | |
name: str = "conv", | |
): | |
super().__init__() | |
self.channels = channels | |
self.out_channels = out_channels or channels | |
self.use_conv = use_conv | |
self.padding = padding | |
stride = 2 | |
self.name = name | |
if use_conv: | |
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) | |
else: | |
assert self.channels == self.out_channels | |
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) | |
def forward(self, inputs: torch.Tensor) -> torch.Tensor: | |
assert inputs.shape[1] == self.channels | |
return self.conv(inputs) | |
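# Illustrative usage sketch (added for clarity; not part of the library). Both branches of
# Downsample1D (stride-2 convolution or average pooling) halve the temporal length; sizes
# below are demo assumptions.
def _demo_downsample_1d():
    down = Downsample1D(channels=32, use_conv=True)
    x = torch.randn(2, 32, 64)  # (batch, channels, length)
    out = down(x)
    assert out.shape == (2, 32, 32)
    return out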
class Downsample2D(nn.Module): | |
"""A 2D downsampling layer with an optional convolution. | |
Parameters: | |
channels (`int`): | |
number of channels in the inputs and outputs. | |
use_conv (`bool`, default `False`): | |
option to use a convolution. | |
out_channels (`int`, optional): | |
number of output channels. Defaults to `channels`. | |
padding (`int`, default `1`): | |
padding for the convolution. | |
name (`str`, default `conv`): | |
name of the downsampling 2D layer. | |
""" | |
def __init__( | |
self, | |
channels: int, | |
use_conv: bool = False, | |
out_channels: Optional[int] = None, | |
padding: int = 1, | |
name: str = "conv", | |
kernel_size=3, | |
norm_type=None, | |
eps=None, | |
elementwise_affine=None, | |
bias=True, | |
): | |
super().__init__() | |
self.channels = channels | |
self.out_channels = out_channels or channels | |
self.use_conv = use_conv | |
self.padding = padding | |
stride = 2 | |
self.name = name | |
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv | |
if norm_type == "ln_norm": | |
self.norm = nn.LayerNorm(channels, eps, elementwise_affine) | |
elif norm_type == "rms_norm": | |
self.norm = RMSNorm(channels, eps, elementwise_affine) | |
elif norm_type is None: | |
self.norm = None | |
else: | |
raise ValueError(f"unknown norm_type: {norm_type}") | |
if use_conv: | |
conv = conv_cls( | |
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias | |
) | |
else: | |
assert self.channels == self.out_channels | |
conv = nn.AvgPool2d(kernel_size=stride, stride=stride) | |
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed | |
if name == "conv": | |
self.Conv2d_0 = conv | |
self.conv = conv | |
elif name == "Conv2d_0": | |
self.conv = conv | |
else: | |
self.conv = conv | |
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: | |
assert hidden_states.shape[1] == self.channels | |
if self.norm is not None: | |
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | |
if self.use_conv and self.padding == 0: | |
pad = (0, 1, 0, 1) | |
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) | |
assert hidden_states.shape[1] == self.channels | |
if not USE_PEFT_BACKEND: | |
if isinstance(self.conv, LoRACompatibleConv): | |
hidden_states = self.conv(hidden_states, scale) | |
else: | |
hidden_states = self.conv(hidden_states) | |
else: | |
hidden_states = self.conv(hidden_states) | |
return hidden_states | |
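# Illustrative usage sketch (added for clarity; not part of the library). With the default
# kernel_size=3, stride 2, and padding=1, the convolutional branch of Downsample2D halves
# both spatial dimensions; channel count and resolution are demo assumptions.
def _demo_downsample_2d_block():
    down = Downsample2D(channels=16, use_conv=True)
    x = torch.randn(1, 16, 32, 32)
    out = down(x)
    assert out.shape == (1, 16, 16, 16)
    return out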
class FirDownsample2D(nn.Module): | |
"""A 2D FIR downsampling layer with an optional convolution. | |
Parameters: | |
channels (`int`): | |
number of channels in the inputs and outputs. | |
use_conv (`bool`, default `False`): | |
option to use a convolution. | |
out_channels (`int`, optional): | |
number of output channels. Defaults to `channels`. | |
fir_kernel (`tuple`, default `(1, 3, 3, 1)`): | |
kernel for the FIR filter. | |
""" | |
def __init__( | |
self, | |
channels: Optional[int] = None, | |
out_channels: Optional[int] = None, | |
use_conv: bool = False, | |
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), | |
): | |
super().__init__() | |
out_channels = out_channels if out_channels else channels | |
if use_conv: | |
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) | |
self.fir_kernel = fir_kernel | |
self.use_conv = use_conv | |
self.out_channels = out_channels | |
def _downsample_2d( | |
self, | |
hidden_states: torch.FloatTensor, | |
weight: Optional[torch.FloatTensor] = None, | |
kernel: Optional[torch.FloatTensor] = None, | |
factor: int = 2, | |
gain: float = 1, | |
) -> torch.FloatTensor: | |
"""Fused `Conv2d()` followed by `downsample_2d()`. | |
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more | |
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
arbitrary order. | |
Args: | |
hidden_states (`torch.FloatTensor`): | |
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. | |
weight (`torch.FloatTensor`, *optional*): | |
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be | |
performed by `inChannels = x.shape[0] // numGroups`. | |
kernel (`torch.FloatTensor`, *optional*): | |
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which | |
corresponds to average pooling. | |
factor (`int`, *optional*, default to `2`): | |
Integer downsampling factor. | |
gain (`float`, *optional*, default to `1.0`): | |
Scaling factor for signal magnitude. | |
Returns: | |
output (`torch.FloatTensor`): | |
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same | |
datatype as `x`. | |
""" | |
assert isinstance(factor, int) and factor >= 1 | |
if kernel is None: | |
kernel = [1] * factor | |
# setup kernel | |
kernel = torch.tensor(kernel, dtype=torch.float32) | |
if kernel.ndim == 1: | |
kernel = torch.outer(kernel, kernel) | |
kernel /= torch.sum(kernel) | |
kernel = kernel * gain | |
if self.use_conv: | |
_, _, convH, convW = weight.shape | |
pad_value = (kernel.shape[0] - factor) + (convW - 1) | |
stride_value = [factor, factor] | |
upfirdn_input = upfirdn2d_native( | |
hidden_states, | |
torch.tensor(kernel, device=hidden_states.device), | |
pad=((pad_value + 1) // 2, pad_value // 2), | |
) | |
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) | |
else: | |
pad_value = kernel.shape[0] - factor | |
output = upfirdn2d_native( | |
hidden_states, | |
torch.tensor(kernel, device=hidden_states.device), | |
down=factor, | |
pad=((pad_value + 1) // 2, pad_value // 2), | |
) | |
return output | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
if self.use_conv: | |
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) | |
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) | |
else: | |
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) | |
return hidden_states | |
# Downsampling layer used in the k-upscaler; FirDownsample2D/FirUpsample2D might be usable instead.
class KDownsample2D(nn.Module): | |
r"""A 2D K-downsampling layer. | |
Parameters: | |
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. | |
""" | |
def __init__(self, pad_mode: str = "reflect"): | |
super().__init__() | |
self.pad_mode = pad_mode | |
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) | |
self.pad = kernel_1d.shape[1] // 2 - 1 | |
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) | |
def forward(self, inputs: torch.Tensor) -> torch.Tensor: | |
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) | |
weight = inputs.new_zeros( | |
[ | |
inputs.shape[1], | |
inputs.shape[1], | |
self.kernel.shape[0], | |
self.kernel.shape[1], | |
] | |
) | |
indices = torch.arange(inputs.shape[1], device=inputs.device) | |
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) | |
weight[indices, indices] = kernel | |
return F.conv2d(inputs, weight, stride=2) | |
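# Illustrative usage sketch (added for clarity; not part of the library). KDownsample2D
# applies a fixed 4-tap binomial filter per channel and subsamples by 2 via a stride-2
# convolution; the resolution below is a demo assumption.
def _demo_k_downsample_2d():
    down = KDownsample2D()
    x = torch.randn(1, 3, 32, 32)
    out = down(x)
    assert out.shape == (1, 3, 16, 16)
    return out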
def downsample_2d( | |
hidden_states: torch.FloatTensor, | |
kernel: Optional[torch.FloatTensor] = None, | |
factor: int = 2, | |
gain: float = 1, | |
) -> torch.FloatTensor: | |
r"""Downsample2D a batch of 2D images with the given filter. | |
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the | |
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the | |
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its | |
shape is a multiple of the downsampling factor. | |
Args: | |
hidden_states (`torch.FloatTensor`) | |
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. | |
kernel (`torch.FloatTensor`, *optional*): | |
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which | |
corresponds to average pooling. | |
factor (`int`, *optional*, default to `2`): | |
Integer downsampling factor. | |
gain (`float`, *optional*, default to `1.0`): | |
Scaling factor for signal magnitude. | |
Returns: | |
output (`torch.FloatTensor`): | |
Tensor of the shape `[N, C, H // factor, W // factor]` | |
""" | |
assert isinstance(factor, int) and factor >= 1 | |
if kernel is None: | |
kernel = [1] * factor | |
kernel = torch.tensor(kernel, dtype=torch.float32) | |
if kernel.ndim == 1: | |
kernel = torch.outer(kernel, kernel) | |
kernel /= torch.sum(kernel) | |
kernel = kernel * gain | |
pad_value = kernel.shape[0] - factor | |
output = upfirdn2d_native( | |
hidden_states, | |
kernel.to(device=hidden_states.device), | |
down=factor, | |
pad=((pad_value + 1) // 2, pad_value // 2), | |
) | |
return output | |
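# Illustrative usage sketch (added for clarity; not part of the library). The functional
# downsample_2d above defaults to an average-pooling-like FIR kernel ([1] * factor), so
# with factor=2 it halves both spatial dimensions; the input size is a demo assumption.
def _demo_functional_downsample_2d():
    x = torch.randn(1, 4, 16, 16)
    out = downsample_2d(x, factor=2)
    assert out.shape == (1, 4, 8, 8)
    return out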
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. | |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import inspect | |
import itertools | |
import os | |
import re | |
from collections import OrderedDict | |
from functools import partial | |
from typing import Any, Callable, List, Optional, Tuple, Union | |
import safetensors | |
import torch | |
from huggingface_hub import create_repo | |
from huggingface_hub.utils import validate_hf_hub_args | |
from torch import Tensor, nn | |
from .. import __version__ | |
from ..utils import ( | |
CONFIG_NAME, | |
FLAX_WEIGHTS_NAME, | |
SAFETENSORS_FILE_EXTENSION, | |
SAFETENSORS_WEIGHTS_NAME, | |
WEIGHTS_NAME, | |
_add_variant, | |
_get_model_file, | |
deprecate, | |
is_accelerate_available, | |
is_torch_version, | |
logging, | |
) | |
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card | |
logger = logging.get_logger(__name__) | |
if is_torch_version(">=", "1.9.0"): | |
_LOW_CPU_MEM_USAGE_DEFAULT = True | |
else: | |
_LOW_CPU_MEM_USAGE_DEFAULT = False | |
if is_accelerate_available(): | |
import accelerate | |
from accelerate.utils import set_module_tensor_to_device | |
from accelerate.utils.versions import is_torch_version | |
def get_parameter_device(parameter: torch.nn.Module) -> torch.device: | |
try: | |
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) | |
return next(parameters_and_buffers).device | |
except StopIteration: | |
# For torch.nn.DataParallel compatibility in PyTorch 1.5 | |
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: | |
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] | |
return tuples | |
gen = parameter._named_members(get_members_fn=find_tensor_attributes) | |
first_tuple = next(gen) | |
return first_tuple[1].device | |
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: | |
try: | |
params = tuple(parameter.parameters()) | |
if len(params) > 0: | |
return params[0].dtype | |
buffers = tuple(parameter.buffers()) | |
if len(buffers) > 0: | |
return buffers[0].dtype | |
except StopIteration: | |
# For torch.nn.DataParallel compatibility in PyTorch 1.5 | |
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: | |
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] | |
return tuples | |
gen = parameter._named_members(get_members_fn=find_tensor_attributes) | |
first_tuple = next(gen) | |
return first_tuple[1].dtype | |
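# Illustrative usage sketch (added for clarity; not part of the library). The two helpers
# above report where a module's parameters live and their dtype; a hedged check on a tiny
# CPU module:
def _demo_parameter_helpers():
    tiny = nn.Linear(4, 4)
    assert get_parameter_device(tiny).type == "cpu"
    assert get_parameter_dtype(tiny) == torch.float32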
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): | |
""" | |
Reads a checkpoint file, returning properly formatted errors if they arise. | |
""" | |
try: | |
file_extension = os.path.basename(checkpoint_file).split(".")[-1] | |
if file_extension == SAFETENSORS_FILE_EXTENSION: | |
return safetensors.torch.load_file(checkpoint_file, device="cpu") | |
else: | |
return torch.load(checkpoint_file, map_location="cpu") | |
except Exception as e: | |
try: | |
with open(checkpoint_file) as f: | |
if f.read().startswith("version"): | |
raise OSError( | |
"You seem to have cloned a repository without having git-lfs installed. Please install " | |
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " | |
"you cloned." | |
) | |
else: | |
raise ValueError( | |
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " | |
"model. Make sure you have saved the model properly." | |
) from e | |
except (UnicodeDecodeError, ValueError): | |
raise OSError( | |
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " | |
) | |
def load_model_dict_into_meta( | |
model, | |
state_dict: OrderedDict, | |
device: Optional[Union[str, torch.device]] = None, | |
dtype: Optional[Union[str, torch.dtype]] = None, | |
model_name_or_path: Optional[str] = None, | |
) -> List[str]: | |
device = device or torch.device("cpu") | |
dtype = dtype or torch.float32 | |
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) | |
unexpected_keys = [] | |
empty_state_dict = model.state_dict() | |
for param_name, param in state_dict.items(): | |
if param_name not in empty_state_dict: | |
unexpected_keys.append(param_name) | |
continue | |
if empty_state_dict[param_name].shape != param.shape: | |
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" | |
raise ValueError( | |
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." | |
) | |
if accepts_dtype: | |
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) | |
else: | |
set_module_tensor_to_device(model, param_name, device, value=param) | |
return unexpected_keys | |
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: | |
# Convert old format to new format if needed from a PyTorch state_dict | |
# copy state_dict so _load_from_state_dict can modify it | |
state_dict = state_dict.copy() | |
error_msgs = [] | |
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants | |
# so we need to apply the function recursively. | |
def load(module: torch.nn.Module, prefix: str = ""): | |
args = (state_dict, prefix, {}, True, [], [], error_msgs) | |
module._load_from_state_dict(*args) | |
for name, child in module._modules.items(): | |
if child is not None: | |
load(child, prefix + name + ".") | |
load(model_to_load) | |
return error_msgs | |
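# Illustrative usage sketch (added for clarity; not part of the library).
# _load_state_dict_into_model walks the module tree and calls `_load_from_state_dict` on
# every submodule; a hedged round-trip on a tiny model defined only for this demo:
def _demo_load_state_dict_into_model():
    src = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    dst = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    error_msgs = _load_state_dict_into_model(dst, src.state_dict())
    assert error_msgs == []
    assert torch.equal(dst[0].weight, src[0].weight)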
class ModelMixin(torch.nn.Module, PushToHubMixin): | |
r""" | |
Base class for all models. | |
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and | |
saving models. | |
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. | |
""" | |
config_name = CONFIG_NAME | |
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] | |
_supports_gradient_checkpointing = False | |
_keys_to_ignore_on_load_unexpected = None | |
def __init__(self): | |
super().__init__() | |
def __getattr__(self, name: str) -> Any: | |
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing | |
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129. We need to overwrite
        `__getattr__` here in addition so that we don't trigger `torch.nn.Module`'s `__getattr__`:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module | |
""" | |
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) | |
is_attribute = name in self.__dict__ | |
if is_in_config and not is_attribute: | |
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." | |
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) | |
return self._internal_dict[name] | |
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module | |
return super().__getattr__(name) | |
@property | |
def is_gradient_checkpointing(self) -> bool: | |
""" | |
Whether gradient checkpointing is activated for this model or not. | |
""" | |
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) | |
def enable_gradient_checkpointing(self) -> None: | |
""" | |
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or | |
*checkpoint activations* in other frameworks). | |
""" | |
if not self._supports_gradient_checkpointing: | |
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") | |
self.apply(partial(self._set_gradient_checkpointing, value=True)) | |
def disable_gradient_checkpointing(self) -> None: | |
""" | |
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or | |
*checkpoint activations* in other frameworks). | |
""" | |
if self._supports_gradient_checkpointing: | |
self.apply(partial(self._set_gradient_checkpointing, value=False)) | |
def set_use_memory_efficient_attention_xformers( | |
self, valid: bool, attention_op: Optional[Callable] = None | |
) -> None: | |
# Recursively walk through all the children. | |
# Any children which exposes the set_use_memory_efficient_attention_xformers method | |
# gets the message | |
def fn_recursive_set_mem_eff(module: torch.nn.Module): | |
if hasattr(module, "set_use_memory_efficient_attention_xformers"): | |
module.set_use_memory_efficient_attention_xformers(valid, attention_op) | |
for child in module.children(): | |
fn_recursive_set_mem_eff(child) | |
for module in self.children(): | |
if isinstance(module, torch.nn.Module): | |
fn_recursive_set_mem_eff(module) | |
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None: | |
r""" | |
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). | |
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during | |
inference. Speed up during training is not guaranteed. | |
<Tip warning={true}> | |
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes | |
precedent. | |
</Tip> | |
Parameters: | |
attention_op (`Callable`, *optional*): | |
Override the default `None` operator for use as `op` argument to the | |
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) | |
function of xFormers. | |
Examples: | |
```py | |
>>> import torch | |
>>> from diffusers import UNet2DConditionModel | |
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp | |
>>> model = UNet2DConditionModel.from_pretrained( | |
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16 | |
... ) | |
>>> model = model.to("cuda") | |
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) | |
``` | |
""" | |
self.set_use_memory_efficient_attention_xformers(True, attention_op) | |
def disable_xformers_memory_efficient_attention(self) -> None: | |
r""" | |
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). | |
""" | |
self.set_use_memory_efficient_attention_xformers(False) | |
def save_pretrained( | |
self, | |
save_directory: Union[str, os.PathLike], | |
is_main_process: bool = True, | |
save_function: Optional[Callable] = None, | |
safe_serialization: bool = True, | |
variant: Optional[str] = None, | |
push_to_hub: bool = False, | |
**kwargs, | |
): | |
""" | |
Save a model and its configuration file to a directory so that it can be reloaded using the | |
[`~models.ModelMixin.from_pretrained`] class method. | |
Arguments: | |
save_directory (`str` or `os.PathLike`): | |
Directory to save a model and its configuration file to. Will be created if it doesn't exist. | |
is_main_process (`bool`, *optional*, defaults to `True`): | |
                Whether the process calling this is the main process or not. Useful during distributed training when you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main | |
process to avoid race conditions. | |
save_function (`Callable`): | |
The function to use to save the state dictionary. Useful during distributed training when you need to | |
replace `torch.save` with another method. Can be configured with the environment variable | |
`DIFFUSERS_SAVE_MODE`. | |
safe_serialization (`bool`, *optional*, defaults to `True`): | |
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. | |
variant (`str`, *optional*): | |
If specified, weights are saved in the format `diffusion_pytorch_model.<variant>.bin`. |
push_to_hub (`bool`, *optional*, defaults to `False`): | |
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the | |
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your | |
namespace). | |
kwargs (`Dict[str, Any]`, *optional*): | |
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. | |
""" | |
if os.path.isfile(save_directory): | |
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") | |
return | |
os.makedirs(save_directory, exist_ok=True) | |
if push_to_hub: | |
commit_message = kwargs.pop("commit_message", None) | |
private = kwargs.pop("private", False) | |
create_pr = kwargs.pop("create_pr", False) | |
token = kwargs.pop("token", None) | |
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) | |
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id | |
# Only save the model itself if we are using distributed training | |
model_to_save = self | |
# Attach architecture to the config | |
# Save the config | |
if is_main_process: | |
model_to_save.save_config(save_directory) | |
# Save the model | |
state_dict = model_to_save.state_dict() | |
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME | |
weights_name = _add_variant(weights_name, variant) | |
# Save the model | |
if safe_serialization: | |
safetensors.torch.save_file( | |
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"} | |
) | |
else: | |
torch.save(state_dict, os.path.join(save_directory, weights_name)) | |
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") | |
if push_to_hub: | |
# Create a new empty model card and eventually tag it | |
model_card = load_or_create_model_card(repo_id, token=token) | |
model_card = populate_model_card(model_card) | |
model_card.save(os.path.join(save_directory, "README.md")) | |
self._upload_folder( | |
save_directory, | |
repo_id, | |
token=token, | |
commit_message=commit_message, | |
create_pr=create_pr, | |
) | |
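# --- Usage sketch (added for illustration; not part of the original diffusers source) --- |
# Saving a model locally with safetensors and an "fp16" variant. The directory |
# "./my_unet" is a hypothetical example path. |
# |
#     import torch |
#     from diffusers import UNet2DConditionModel |
# |
#     unet = UNet2DConditionModel.from_pretrained( |
#         "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16 |
#     ) |
#     unet.save_pretrained("./my_unet", safe_serialization=True, variant="fp16") |
#     # With the filename constants used above, the weights land in: |
#     #     ./my_unet/diffusion_pytorch_model.fp16.safetensors |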
@classmethod | |
@validate_hf_hub_args | |
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): | |
r""" | |
Instantiate a pretrained PyTorch model from a pretrained model configuration. | |
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To | |
train the model, set it back in training mode with `model.train()`. | |
Parameters: | |
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): | |
Can be either: | |
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on | |
the Hub. | |
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved | |
with [`~ModelMixin.save_pretrained`]. | |
cache_dir (`Union[str, os.PathLike]`, *optional*): | |
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | |
is not used. | |
torch_dtype (`str` or `torch.dtype`, *optional*): | |
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the | |
dtype is automatically derived from the model's weights. | |
force_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
cached versions if they exist. | |
resume_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any | |
incompletely downloaded files are deleted. | |
proxies (`Dict[str, str]`, *optional*): | |
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | |
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
output_loading_info (`bool`, *optional*, defaults to `False`): | |
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. | |
local_files_only (`bool`, *optional*, defaults to `False`): |
Whether to only load local model weights and configuration files or not. If set to `True`, the model | |
won't be downloaded from the Hub. | |
token (`str` or *bool*, *optional*): | |
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | |
`diffusers-cli login` (stored in `~/.huggingface`) is used. | |
revision (`str`, *optional*, defaults to `"main"`): | |
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | |
allowed by Git. | |
from_flax (`bool`, *optional*, defaults to `False`): | |
Load the model weights from a Flax checkpoint save file. | |
subfolder (`str`, *optional*, defaults to `""`): | |
The subfolder location of a model file within a larger model repository on the Hub or locally. | |
mirror (`str`, *optional*): | |
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not | |
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more | |
information. | |
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): | |
A map that specifies where each submodule should go. It doesn't need to be defined for each | |
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the | |
same device. | |
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For | |
more information about each option see [designing a device | |
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). | |
max_memory (`Dict`, *optional*): | |
A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for |
each GPU and the available CPU RAM if unset. | |
offload_folder (`str` or `os.PathLike`, *optional*): | |
The path to offload weights if `device_map` contains the value `"disk"`. | |
offload_state_dict (`bool`, *optional*): | |
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if | |
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` | |
when there is some disk offload. | |
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): | |
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also |
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. | |
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this | |
argument to `True` will raise an error. | |
variant (`str`, *optional*): | |
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when | |
loading `from_flax`. | |
use_safetensors (`bool`, *optional*, defaults to `None`): | |
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the | |
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` | |
weights. If set to `False`, `safetensors` weights are not loaded. | |
<Tip> | |
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with | |
`huggingface-cli login`. You can also activate the special | |
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a | |
firewalled environment. | |
</Tip> | |
Example: | |
```py | |
from diffusers import UNet2DConditionModel | |
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") | |
``` | |
If you get the error message below, you need to finetune the weights for your downstream task: | |
```bash | |
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: | |
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated | |
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. | |
``` | |
""" | |
cache_dir = kwargs.pop("cache_dir", None) | |
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) | |
force_download = kwargs.pop("force_download", False) | |
from_flax = kwargs.pop("from_flax", False) | |
resume_download = kwargs.pop("resume_download", False) | |
proxies = kwargs.pop("proxies", None) | |
output_loading_info = kwargs.pop("output_loading_info", False) | |
local_files_only = kwargs.pop("local_files_only", None) | |
token = kwargs.pop("token", None) | |
revision = kwargs.pop("revision", None) | |
torch_dtype = kwargs.pop("torch_dtype", None) | |
subfolder = kwargs.pop("subfolder", None) | |
device_map = kwargs.pop("device_map", None) | |
max_memory = kwargs.pop("max_memory", None) | |
offload_folder = kwargs.pop("offload_folder", None) | |
offload_state_dict = kwargs.pop("offload_state_dict", False) | |
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) | |
variant = kwargs.pop("variant", None) | |
use_safetensors = kwargs.pop("use_safetensors", None) | |
allow_pickle = False | |
if use_safetensors is None: | |
use_safetensors = True | |
allow_pickle = True | |
if low_cpu_mem_usage and not is_accelerate_available(): | |
low_cpu_mem_usage = False | |
logger.warning( | |
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" | |
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" | |
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" | |
" install accelerate\n```\n." | |
) | |
if device_map is not None and not is_accelerate_available(): | |
raise NotImplementedError( | |
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" | |
" `device_map=None`. You can install accelerate with `pip install accelerate`." | |
) | |
# Check if we can handle device_map and dispatching the weights | |
if device_map is not None and not is_torch_version(">=", "1.9.0"): | |
raise NotImplementedError( | |
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" | |
" `device_map=None`." | |
) | |
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): | |
raise NotImplementedError( | |
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" | |
" `low_cpu_mem_usage=False`." | |
) | |
if low_cpu_mem_usage is False and device_map is not None: | |
raise ValueError( | |
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" | |
" dispatching. Please make sure to set `low_cpu_mem_usage=True`." | |
) | |
# Load config if we don't provide a configuration | |
config_path = pretrained_model_name_or_path | |
user_agent = { | |
"diffusers": __version__, | |
"file_type": "model", | |
"framework": "pytorch", | |
} | |
# load config | |
config, unused_kwargs, commit_hash = cls.load_config( | |
config_path, | |
cache_dir=cache_dir, | |
return_unused_kwargs=True, | |
return_commit_hash=True, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
token=token, | |
revision=revision, | |
subfolder=subfolder, | |
device_map=device_map, | |
max_memory=max_memory, | |
offload_folder=offload_folder, | |
offload_state_dict=offload_state_dict, | |
user_agent=user_agent, | |
**kwargs, | |
) | |
# load model | |
model_file = None | |
if from_flax: | |
model_file = _get_model_file( | |
pretrained_model_name_or_path, | |
weights_name=FLAX_WEIGHTS_NAME, | |
cache_dir=cache_dir, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
token=token, | |
revision=revision, | |
subfolder=subfolder, | |
user_agent=user_agent, | |
commit_hash=commit_hash, | |
) | |
model = cls.from_config(config, **unused_kwargs) | |
# Convert the weights | |
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model | |
model = load_flax_checkpoint_in_pytorch_model(model, model_file) | |
else: | |
if use_safetensors: | |
try: | |
model_file = _get_model_file( | |
pretrained_model_name_or_path, | |
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), | |
cache_dir=cache_dir, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
token=token, | |
revision=revision, | |
subfolder=subfolder, | |
user_agent=user_agent, | |
commit_hash=commit_hash, | |
) | |
except IOError as e: | |
if not allow_pickle: | |
raise e | |
pass | |
if model_file is None: | |
model_file = _get_model_file( | |
pretrained_model_name_or_path, | |
weights_name=_add_variant(WEIGHTS_NAME, variant), | |
cache_dir=cache_dir, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
token=token, | |
revision=revision, | |
subfolder=subfolder, | |
user_agent=user_agent, | |
commit_hash=commit_hash, | |
) | |
if low_cpu_mem_usage: | |
# Instantiate model with empty weights | |
with accelerate.init_empty_weights(): | |
model = cls.from_config(config, **unused_kwargs) | |
# if device_map is None, load the state dict and move the params from meta device to the cpu | |
if device_map is None: | |
param_device = "cpu" | |
state_dict = load_state_dict(model_file, variant=variant) | |
model._convert_deprecated_attention_blocks(state_dict) | |
# move the params from meta device to cpu | |
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) | |
if len(missing_keys) > 0: | |
raise ValueError( | |
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" | |
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" | |
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" | |
" those weights or else make sure your checkpoint file is correct." | |
) | |
unexpected_keys = load_model_dict_into_meta( | |
model, | |
state_dict, | |
device=param_device, | |
dtype=torch_dtype, | |
model_name_or_path=pretrained_model_name_or_path, | |
) | |
if cls._keys_to_ignore_on_load_unexpected is not None: | |
for pat in cls._keys_to_ignore_on_load_unexpected: | |
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] | |
if len(unexpected_keys) > 0: | |
logger.warning( |
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}" |
) |
else: # else let accelerate handle loading and dispatching. | |
# Load weights and dispatch according to the device_map | |
# by default the device_map is None and the weights are loaded on the CPU | |
try: | |
accelerate.load_checkpoint_and_dispatch( | |
model, | |
model_file, | |
device_map, | |
max_memory=max_memory, | |
offload_folder=offload_folder, | |
offload_state_dict=offload_state_dict, | |
dtype=torch_dtype, | |
) | |
except AttributeError as e: | |
# When using accelerate loading, we do not have the ability to load the state | |
# dict and rename the weight names manually. Additionally, accelerate skips | |
# torch loading conventions and directly writes into `module.{_buffers, _parameters}` | |
# (which look like they should be private variables?), so we can't use the standard hooks | |
# to rename parameters on load. We need to mimic the original weight names so the correct | |
# attributes are available. After we have loaded the weights, we convert the deprecated | |
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert | |
# the weights so we don't have to do this again. | |
if "'Attention' object has no attribute" in str(e): | |
logger.warning( |
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" | |
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block" | |
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," | |
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," | |
" please also re-upload it or open a PR on the original repository." | |
) | |
model._temp_convert_self_to_deprecated_attention_blocks() | |
accelerate.load_checkpoint_and_dispatch( | |
model, | |
model_file, | |
device_map, | |
max_memory=max_memory, | |
offload_folder=offload_folder, | |
offload_state_dict=offload_state_dict, | |
dtype=torch_dtype, | |
) | |
model._undo_temp_convert_self_to_deprecated_attention_blocks() | |
else: | |
raise e | |
loading_info = { | |
"missing_keys": [], | |
"unexpected_keys": [], | |
"mismatched_keys": [], | |
"error_msgs": [], | |
} | |
else: | |
model = cls.from_config(config, **unused_kwargs) | |
state_dict = load_state_dict(model_file, variant=variant) | |
model._convert_deprecated_attention_blocks(state_dict) | |
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( | |
model, | |
state_dict, | |
model_file, | |
pretrained_model_name_or_path, | |
ignore_mismatched_sizes=ignore_mismatched_sizes, | |
) | |
loading_info = { | |
"missing_keys": missing_keys, | |
"unexpected_keys": unexpected_keys, | |
"mismatched_keys": mismatched_keys, | |
"error_msgs": error_msgs, | |
} | |
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): | |
raise ValueError( | |
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." | |
) | |
elif torch_dtype is not None: | |
model = model.to(torch_dtype) | |
model.register_to_config(_name_or_path=pretrained_model_name_or_path) | |
# Set model in evaluation mode to deactivate DropOut modules by default | |
model.eval() | |
if output_loading_info: | |
return model, loading_info | |
return model | |
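# --- Usage sketch (added for illustration; not part of the original diffusers source) --- |
# Loading a UNet in fp16 from safetensors and inspecting the loading report. |
# Requires network access (or a warm local cache) for the checkpoint id used in |
# the docstring above. |
# |
#     import torch |
#     from diffusers import UNet2DConditionModel |
# |
#     unet, loading_info = UNet2DConditionModel.from_pretrained( |
#         "runwayml/stable-diffusion-v1-5", |
#         subfolder="unet", |
#         torch_dtype=torch.float16, |
#         use_safetensors=True, |
#         output_loading_info=True, |
#     ) |
#     assert unet.dtype == torch.float16 |
#     assert len(loading_info["missing_keys"]) == 0 |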
@classmethod | |
def _load_pretrained_model( | |
cls, | |
model, | |
state_dict: OrderedDict, | |
resolved_archive_file, | |
pretrained_model_name_or_path: Union[str, os.PathLike], | |
ignore_mismatched_sizes: bool = False, | |
): | |
# Retrieve missing & unexpected_keys | |
model_state_dict = model.state_dict() | |
loaded_keys = list(state_dict.keys()) | |
expected_keys = list(model_state_dict.keys()) | |
original_loaded_keys = loaded_keys | |
missing_keys = list(set(expected_keys) - set(loaded_keys)) | |
unexpected_keys = list(set(loaded_keys) - set(expected_keys)) | |
# Make sure we are able to load base models as well as derived models (with heads) | |
model_to_load = model | |
def _find_mismatched_keys( | |
state_dict, | |
model_state_dict, | |
loaded_keys, | |
ignore_mismatched_sizes, | |
): | |
mismatched_keys = [] | |
if ignore_mismatched_sizes: | |
for checkpoint_key in loaded_keys: | |
model_key = checkpoint_key | |
if ( | |
model_key in model_state_dict | |
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape | |
): | |
mismatched_keys.append( | |
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) | |
) | |
del state_dict[checkpoint_key] | |
return mismatched_keys | |
if state_dict is not None: | |
# Whole checkpoint | |
mismatched_keys = _find_mismatched_keys( | |
state_dict, | |
model_state_dict, | |
original_loaded_keys, | |
ignore_mismatched_sizes, | |
) | |
error_msgs = _load_state_dict_into_model(model_to_load, state_dict) | |
if len(error_msgs) > 0: | |
error_msg = "\n\t".join(error_msgs) | |
if "size mismatch" in error_msg: | |
error_msg += ( | |
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." | |
) | |
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") | |
if len(unexpected_keys) > 0: | |
logger.warning( | |
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" | |
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" | |
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" | |
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a" | |
" BertForPreTraining model).\n- This IS NOT expected if you are initializing" | |
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" | |
" identical (initializing a BertForSequenceClassification model from a" | |
" BertForSequenceClassification model)." | |
) | |
else: | |
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") | |
if len(missing_keys) > 0: | |
logger.warning( | |
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" | |
" TRAIN this model on a down-stream task to be able to use it for predictions and inference." | |
) | |
elif len(mismatched_keys) == 0: | |
logger.info( | |
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" | |
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" | |
" without further training." | |
) | |
if len(mismatched_keys) > 0: | |
mismatched_warning = "\n".join( | |
[ | |
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" | |
for key, shape1, shape2 in mismatched_keys | |
] | |
) | |
logger.warning( | |
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" | |
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" | |
" able to use it for predictions and inference." | |
) | |
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs | |
@property | |
def device(self) -> torch.device: | |
""" | |
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same | |
device). | |
""" | |
return get_parameter_device(self) | |
@property | |
def dtype(self) -> torch.dtype: | |
""" | |
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). | |
""" | |
return get_parameter_dtype(self) | |
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: | |
""" | |
Get number of (trainable or non-embedding) parameters in the module. | |
Args: | |
only_trainable (`bool`, *optional*, defaults to `False`): | |
Whether or not to return only the number of trainable parameters. | |
exclude_embeddings (`bool`, *optional*, defaults to `False`): | |
Whether or not to return only the number of non-embedding parameters. | |
Returns: | |
`int`: The number of parameters. | |
Example: | |
```py | |
from diffusers import UNet2DConditionModel | |
model_id = "runwayml/stable-diffusion-v1-5" | |
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet") | |
unet.num_parameters(only_trainable=True) | |
859520964 | |
``` | |
""" | |
if exclude_embeddings: | |
embedding_param_names = [ | |
f"{name}.weight" | |
for name, module_type in self.named_modules() | |
if isinstance(module_type, torch.nn.Embedding) | |
] | |
non_embedding_parameters = [ | |
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names | |
] | |
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) | |
else: | |
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) | |
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: | |
deprecated_attention_block_paths = [] | |
def recursive_find_attn_block(name, module): | |
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: | |
deprecated_attention_block_paths.append(name) | |
for sub_name, sub_module in module.named_children(): | |
sub_name = sub_name if name == "" else f"{name}.{sub_name}" | |
recursive_find_attn_block(sub_name, sub_module) | |
recursive_find_attn_block("", self) | |
# NOTE: we have to check if the deprecated parameters are in the state dict | |
# because it is possible we are loading from a state dict that was already | |
# converted | |
for path in deprecated_attention_block_paths: | |
# group_norm path stays the same | |
# query -> to_q | |
if f"{path}.query.weight" in state_dict: | |
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") | |
if f"{path}.query.bias" in state_dict: | |
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") | |
# key -> to_k | |
if f"{path}.key.weight" in state_dict: | |
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") | |
if f"{path}.key.bias" in state_dict: | |
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") | |
# value -> to_v | |
if f"{path}.value.weight" in state_dict: | |
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") | |
if f"{path}.value.bias" in state_dict: | |
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") | |
# proj_attn -> to_out.0 | |
if f"{path}.proj_attn.weight" in state_dict: | |
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") | |
if f"{path}.proj_attn.bias" in state_dict: | |
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") | |
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None: | |
deprecated_attention_block_modules = [] | |
def recursive_find_attn_block(module): | |
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: | |
deprecated_attention_block_modules.append(module) | |
for sub_module in module.children(): | |
recursive_find_attn_block(sub_module) | |
recursive_find_attn_block(self) | |
for module in deprecated_attention_block_modules: | |
module.query = module.to_q | |
module.key = module.to_k | |
module.value = module.to_v | |
module.proj_attn = module.to_out[0] | |
# We don't _have_ to delete the old attributes, but it's helpful to ensure | |
# that _all_ the weights are loaded into the new attributes and we're not | |
# making an incorrect assumption that this model should be converted when | |
# it really shouldn't be. | |
del module.to_q | |
del module.to_k | |
del module.to_v | |
del module.to_out | |
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None: | |
deprecated_attention_block_modules = [] | |
def recursive_find_attn_block(module) -> None: | |
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: | |
deprecated_attention_block_modules.append(module) | |
for sub_module in module.children(): | |
recursive_find_attn_block(sub_module) | |
recursive_find_attn_block(self) | |
for module in deprecated_attention_block_modules: | |
module.to_q = module.query | |
module.to_k = module.key | |
module.to_v = module.value | |
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)]) | |
del module.query | |
del module.key | |
del module.value | |
del module.proj_attn | |
from typing import Any, Dict, Optional | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ..utils import USE_PEFT_BACKEND | |
from ..utils.torch_utils import maybe_allow_in_graph | |
from .activations import GEGLU, GELU, ApproximateGELU | |
from .attention_processor import Attention | |
from .embeddings import SinusoidalPositionalEmbedding | |
from .lora import LoRACompatibleLinear | |
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm | |
def _chunked_feed_forward( | |
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None | |
): | |
# "feed_forward_chunk_size" can be used to save memory | |
if hidden_states.shape[chunk_dim] % chunk_size != 0: | |
raise ValueError( | |
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." | |
) | |
num_chunks = hidden_states.shape[chunk_dim] // chunk_size | |
if lora_scale is None: | |
ff_output = torch.cat( | |
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], | |
dim=chunk_dim, | |
) | |
else: | |
# TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete |
ff_output = torch.cat( | |
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], | |
dim=chunk_dim, | |
) | |
return ff_output | |
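# --- Illustration (added; not part of the original diffusers source) --- |
# Chunking the feed-forward along the sequence dimension is purely a memory |
# optimization: each token is processed independently, so the chunked result |
# matches the unchunked one up to numerical noise. Toy sizes below. |
# |
#     ff = FeedForward(dim=32, activation_fn="geglu") |
#     x = torch.randn(2, 8, 32)  # (batch, seq_len, dim); seq_len divisible by chunk_size |
#     full = ff(x) |
#     chunked = _chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=4) |
#     assert torch.allclose(full, chunked, atol=1e-6) |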
@maybe_allow_in_graph | |
class GatedSelfAttentionDense(nn.Module): | |
r""" | |
A gated self-attention dense layer that combines visual features and object features. | |
Parameters: | |
query_dim (`int`): The number of channels in the query. | |
context_dim (`int`): The number of channels in the context. | |
n_heads (`int`): The number of heads to use for attention. | |
d_head (`int`): The number of channels in each head. | |
""" | |
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): | |
super().__init__() | |
# we need a linear projection since we need cat visual feature and obj feature | |
self.linear = nn.Linear(context_dim, query_dim) | |
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) | |
self.ff = FeedForward(query_dim, activation_fn="geglu") | |
self.norm1 = nn.LayerNorm(query_dim) | |
self.norm2 = nn.LayerNorm(query_dim) | |
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) | |
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) | |
self.enabled = True | |
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: | |
if not self.enabled: | |
return x | |
n_visual = x.shape[1] | |
objs = self.linear(objs) | |
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] | |
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) | |
return x | |
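# --- Usage sketch (added for illustration; toy sizes, not from the original diffusers source) --- |
# Because `alpha_attn` and `alpha_dense` are initialized to 0, the block starts |
# out as an identity mapping over the visual tokens. |
# |
#     fuser = GatedSelfAttentionDense(query_dim=64, context_dim=32, n_heads=4, d_head=16) |
#     visual = torch.randn(2, 10, 64)  # (batch, n_visual_tokens, query_dim) |
#     objs = torch.randn(2, 5, 32)     # (batch, n_object_tokens, context_dim) |
#     out = fuser(visual, objs) |
#     assert out.shape == visual.shape    # only the visual tokens are returned |
#     assert torch.allclose(out, visual)  # identity at initialization (tanh(0) == 0) |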
@maybe_allow_in_graph | |
class BasicTransformerBlock(nn.Module): | |
r""" | |
A basic Transformer block. | |
Parameters: | |
dim (`int`): The number of channels in the input and output. | |
num_attention_heads (`int`): The number of heads to use for multi-head attention. | |
attention_head_dim (`int`): The number of channels in each head. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. | |
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. | |
num_embeds_ada_norm (`int`, *optional*): |
The number of diffusion steps used during training. See `Transformer2DModel`. |
attention_bias (`bool`, *optional*, defaults to `False`): |
Configure if the attentions should contain a bias parameter. |
only_cross_attention (`bool`, *optional*): | |
Whether to use only cross-attention layers. In this case two cross attention layers are used. | |
double_self_attention (`bool`, *optional*): | |
Whether to use two self-attention layers. In this case no cross attention layers are used. | |
upcast_attention (`bool`, *optional*): | |
Whether to upcast the attention computation to float32. This is useful for mixed precision training. | |
norm_elementwise_affine (`bool`, *optional*, defaults to `True`): | |
Whether to use learnable elementwise affine parameters for normalization. | |
norm_type (`str`, *optional*, defaults to `"layer_norm"`): | |
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. | |
final_dropout (`bool` *optional*, defaults to False): | |
Whether to apply a final dropout after the last feed-forward layer. | |
attention_type (`str`, *optional*, defaults to `"default"`): | |
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. | |
positional_embeddings (`str`, *optional*, defaults to `None`): | |
The type of positional embeddings to apply. Currently only `"sinusoidal"` is supported. |
num_positional_embeddings (`int`, *optional*, defaults to `None`): | |
The maximum number of positional embeddings to apply. | |
""" | |
def __init__( | |
self, | |
dim: int, | |
num_attention_heads: int, | |
attention_head_dim: int, | |
dropout=0.0, | |
cross_attention_dim: Optional[int] = None, | |
activation_fn: str = "geglu", | |
num_embeds_ada_norm: Optional[int] = None, | |
attention_bias: bool = False, | |
only_cross_attention: bool = False, | |
double_self_attention: bool = False, | |
upcast_attention: bool = False, | |
norm_elementwise_affine: bool = True, | |
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' | |
norm_eps: float = 1e-5, | |
final_dropout: bool = False, | |
attention_type: str = "default", | |
positional_embeddings: Optional[str] = None, | |
num_positional_embeddings: Optional[int] = None, | |
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, | |
ada_norm_bias: Optional[int] = None, | |
ff_inner_dim: Optional[int] = None, | |
ff_bias: bool = True, | |
attention_out_bias: bool = True, | |
): | |
super().__init__() | |
self.only_cross_attention = only_cross_attention | |
# We keep these boolean flags for backward-compatibility. | |
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" | |
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" | |
self.use_ada_layer_norm_single = norm_type == "ada_norm_single" | |
self.use_layer_norm = norm_type == "layer_norm" | |
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" | |
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: | |
raise ValueError( | |
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" | |
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." | |
) | |
self.norm_type = norm_type | |
self.num_embeds_ada_norm = num_embeds_ada_norm | |
if positional_embeddings and (num_positional_embeddings is None): | |
raise ValueError( |
"If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined." |
) |
if positional_embeddings == "sinusoidal": | |
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) | |
else: | |
self.pos_embed = None | |
# Define 3 blocks. Each block has its own normalization layer. | |
# 1. Self-Attn | |
if norm_type == "ada_norm": | |
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) | |
elif norm_type == "ada_norm_zero": | |
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) | |
elif norm_type == "ada_norm_continuous": | |
self.norm1 = AdaLayerNormContinuous( | |
dim, | |
ada_norm_continous_conditioning_embedding_dim, | |
norm_elementwise_affine, | |
norm_eps, | |
ada_norm_bias, | |
"rms_norm", | |
) | |
else: | |
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) | |
self.attn1 = Attention( | |
query_dim=dim, | |
heads=num_attention_heads, | |
dim_head=attention_head_dim, | |
dropout=dropout, | |
bias=attention_bias, | |
cross_attention_dim=cross_attention_dim if only_cross_attention else None, | |
upcast_attention=upcast_attention, | |
out_bias=attention_out_bias, | |
) | |
# 2. Cross-Attn | |
if cross_attention_dim is not None or double_self_attention: | |
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block. | |
# I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during |
# the second cross attention block. | |
if norm_type == "ada_norm": | |
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) | |
elif norm_type == "ada_norm_continuous": | |
self.norm2 = AdaLayerNormContinuous( | |
dim, | |
ada_norm_continous_conditioning_embedding_dim, | |
norm_elementwise_affine, | |
norm_eps, | |
ada_norm_bias, | |
"rms_norm", | |
) | |
else: | |
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) | |
self.attn2 = Attention( | |
query_dim=dim, | |
cross_attention_dim=cross_attention_dim if not double_self_attention else None, | |
heads=num_attention_heads, | |
dim_head=attention_head_dim, | |
dropout=dropout, | |
bias=attention_bias, | |
upcast_attention=upcast_attention, | |
out_bias=attention_out_bias, | |
) # is self-attn if encoder_hidden_states is none | |
else: | |
self.norm2 = None | |
self.attn2 = None | |
# 3. Feed-forward | |
if norm_type == "ada_norm_continuous": | |
self.norm3 = AdaLayerNormContinuous( | |
dim, | |
ada_norm_continous_conditioning_embedding_dim, | |
norm_elementwise_affine, | |
norm_eps, | |
ada_norm_bias, | |
"layer_norm", | |
) | |
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]: | |
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) | |
elif norm_type == "layer_norm_i2vgen": | |
self.norm3 = None | |
self.ff = FeedForward( | |
dim, | |
dropout=dropout, | |
activation_fn=activation_fn, | |
final_dropout=final_dropout, | |
inner_dim=ff_inner_dim, | |
bias=ff_bias, | |
) | |
# 4. Fuser | |
if attention_type == "gated" or attention_type == "gated-text-image": | |
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) | |
# 5. Scale-shift for PixArt-Alpha. | |
if norm_type == "ada_norm_single": | |
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) | |
# let chunk size default to None | |
self._chunk_size = None | |
self._chunk_dim = 0 | |
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): | |
# Sets chunk feed-forward | |
self._chunk_size = chunk_size | |
self._chunk_dim = dim | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
encoder_attention_mask: Optional[torch.FloatTensor] = None, | |
timestep: Optional[torch.LongTensor] = None, | |
cross_attention_kwargs: Dict[str, Any] = None, | |
class_labels: Optional[torch.LongTensor] = None, | |
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
) -> torch.FloatTensor: | |
# Notice that normalization is always applied before the real computation in the following blocks. | |
# 0. Self-Attention | |
batch_size = hidden_states.shape[0] | |
if self.norm_type == "ada_norm": | |
norm_hidden_states = self.norm1(hidden_states, timestep) | |
elif self.norm_type == "ada_norm_zero": | |
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( | |
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype | |
) | |
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: | |
norm_hidden_states = self.norm1(hidden_states) | |
elif self.norm_type == "ada_norm_continuous": | |
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) | |
elif self.norm_type == "ada_norm_single": | |
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( | |
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) | |
).chunk(6, dim=1) | |
norm_hidden_states = self.norm1(hidden_states) | |
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa | |
norm_hidden_states = norm_hidden_states.squeeze(1) | |
else: | |
raise ValueError("Incorrect norm used") | |
if self.pos_embed is not None: | |
norm_hidden_states = self.pos_embed(norm_hidden_states) | |
# 1. Retrieve lora scale. | |
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 | |
# 2. Prepare GLIGEN inputs | |
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} | |
gligen_kwargs = cross_attention_kwargs.pop("gligen", None) | |
attn_output = self.attn1( | |
norm_hidden_states, | |
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, | |
attention_mask=attention_mask, | |
**cross_attention_kwargs, | |
) | |
if self.norm_type == "ada_norm_zero": | |
attn_output = gate_msa.unsqueeze(1) * attn_output | |
elif self.norm_type == "ada_norm_single": | |
attn_output = gate_msa * attn_output | |
hidden_states = attn_output + hidden_states | |
if hidden_states.ndim == 4: | |
hidden_states = hidden_states.squeeze(1) | |
# 2.5 GLIGEN Control | |
if gligen_kwargs is not None: | |
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) | |
# 3. Cross-Attention | |
if self.attn2 is not None: | |
if self.norm_type == "ada_norm": | |
norm_hidden_states = self.norm2(hidden_states, timestep) | |
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: | |
norm_hidden_states = self.norm2(hidden_states) | |
elif self.norm_type == "ada_norm_single": | |
# For PixArt norm2 isn't applied here: | |
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 | |
norm_hidden_states = hidden_states | |
elif self.norm_type == "ada_norm_continuous": | |
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) | |
else: | |
raise ValueError("Incorrect norm") | |
if self.pos_embed is not None and self.norm_type != "ada_norm_single": | |
norm_hidden_states = self.pos_embed(norm_hidden_states) | |
attn_output = self.attn2( | |
norm_hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=encoder_attention_mask, | |
**cross_attention_kwargs, | |
) | |
hidden_states = attn_output + hidden_states | |
# 4. Feed-forward | |
# i2vgen doesn't have this norm 🤷♂️ | |
if self.norm_type == "ada_norm_continuous": | |
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) | |
elif not self.norm_type == "ada_norm_single": | |
norm_hidden_states = self.norm3(hidden_states) | |
if self.norm_type == "ada_norm_zero": | |
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] | |
if self.norm_type == "ada_norm_single": | |
norm_hidden_states = self.norm2(hidden_states) | |
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp | |
if self._chunk_size is not None: | |
# "feed_forward_chunk_size" can be used to save memory | |
ff_output = _chunked_feed_forward( | |
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale | |
) | |
else: | |
ff_output = self.ff(norm_hidden_states, scale=lora_scale) | |
if self.norm_type == "ada_norm_zero": | |
ff_output = gate_mlp.unsqueeze(1) * ff_output | |
elif self.norm_type == "ada_norm_single": | |
ff_output = gate_mlp * ff_output | |
hidden_states = ff_output + hidden_states | |
if hidden_states.ndim == 4: | |
hidden_states = hidden_states.squeeze(1) | |
return hidden_states | |
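# --- Usage sketch (added for illustration; toy sizes, not from the original diffusers source) --- |
# A plain layer-norm block with cross-attention to a text-style context tensor. |
# The residual connections keep the output shape equal to the input shape. |
# |
#     block = BasicTransformerBlock( |
#         dim=64, num_attention_heads=4, attention_head_dim=16, cross_attention_dim=32 |
#     ) |
#     hidden = torch.randn(2, 16, 64)    # (batch, tokens, dim) |
#     context = torch.randn(2, 77, 32)   # (batch, context_tokens, cross_attention_dim) |
#     out = block(hidden, encoder_hidden_states=context) |
#     assert out.shape == hidden.shape |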
@maybe_allow_in_graph | |
class TemporalBasicTransformerBlock(nn.Module): | |
r""" | |
A basic Transformer block for video like data. | |
Parameters: | |
dim (`int`): The number of channels in the input and output. | |
time_mix_inner_dim (`int`): The number of channels for temporal attention. | |
num_attention_heads (`int`): The number of heads to use for multi-head attention. | |
attention_head_dim (`int`): The number of channels in each head. | |
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. | |
""" | |
def __init__( | |
self, | |
dim: int, | |
time_mix_inner_dim: int, | |
num_attention_heads: int, | |
attention_head_dim: int, | |
cross_attention_dim: Optional[int] = None, | |
): | |
super().__init__() | |
self.is_res = dim == time_mix_inner_dim | |
self.norm_in = nn.LayerNorm(dim) | |
# Define 3 blocks. Each block has its own normalization layer. | |
# 1. Self-Attn | |
self.ff_in = FeedForward( | |
dim, | |
dim_out=time_mix_inner_dim, | |
activation_fn="geglu", | |
) | |
self.norm1 = nn.LayerNorm(time_mix_inner_dim) | |
self.attn1 = Attention( | |
query_dim=time_mix_inner_dim, | |
heads=num_attention_heads, | |
dim_head=attention_head_dim, | |
cross_attention_dim=None, | |
) | |
# 2. Cross-Attn | |
if cross_attention_dim is not None: | |
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block. | |
# I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during |
# the second cross attention block. | |
self.norm2 = nn.LayerNorm(time_mix_inner_dim) | |
self.attn2 = Attention( | |
query_dim=time_mix_inner_dim, | |
cross_attention_dim=cross_attention_dim, | |
heads=num_attention_heads, | |
dim_head=attention_head_dim, | |
) # is self-attn if encoder_hidden_states is none | |
else: | |
self.norm2 = None | |
self.attn2 = None | |
# 3. Feed-forward | |
self.norm3 = nn.LayerNorm(time_mix_inner_dim) | |
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") | |
# let chunk size default to None | |
self._chunk_size = None | |
self._chunk_dim = None | |
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): | |
# Sets chunk feed-forward | |
self._chunk_size = chunk_size | |
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off | |
self._chunk_dim = 1 | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
num_frames: int, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
# Notice that normalization is always applied before the real computation in the following blocks. | |
# 0. Self-Attention | |
batch_size = hidden_states.shape[0] | |
batch_frames, seq_length, channels = hidden_states.shape | |
batch_size = batch_frames // num_frames | |
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) | |
hidden_states = hidden_states.permute(0, 2, 1, 3) | |
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) | |
residual = hidden_states | |
hidden_states = self.norm_in(hidden_states) | |
if self._chunk_size is not None: | |
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) | |
else: | |
hidden_states = self.ff_in(hidden_states) | |
if self.is_res: | |
hidden_states = hidden_states + residual | |
norm_hidden_states = self.norm1(hidden_states) | |
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) | |
hidden_states = attn_output + hidden_states | |
# 3. Cross-Attention | |
if self.attn2 is not None: | |
norm_hidden_states = self.norm2(hidden_states) | |
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) | |
hidden_states = attn_output + hidden_states | |
# 4. Feed-forward | |
norm_hidden_states = self.norm3(hidden_states) | |
if self._chunk_size is not None: | |
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) | |
else: | |
ff_output = self.ff(norm_hidden_states) | |
if self.is_res: | |
hidden_states = ff_output + hidden_states | |
else: | |
hidden_states = ff_output | |
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) | |
hidden_states = hidden_states.permute(0, 2, 1, 3) | |
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) | |
return hidden_states | |
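# --- Usage sketch (added for illustration; toy sizes, not from the original diffusers source) --- |
# Video features enter as (batch * num_frames, seq_len, dim); the block attends |
# across frames internally and returns the same layout. |
# |
#     block = TemporalBasicTransformerBlock( |
#         dim=64, time_mix_inner_dim=64, num_attention_heads=4, attention_head_dim=16 |
#     ) |
#     num_frames = 8 |
#     hidden = torch.randn(2 * num_frames, 16, 64)  # (batch * frames, tokens, dim) |
#     out = block(hidden, num_frames=num_frames) |
#     assert out.shape == hidden.shape |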
class SkipFFTransformerBlock(nn.Module): | |
def __init__( | |
self, | |
dim: int, | |
num_attention_heads: int, | |
attention_head_dim: int, | |
kv_input_dim: int, | |
kv_input_dim_proj_use_bias: bool, | |
dropout=0.0, | |
cross_attention_dim: Optional[int] = None, | |
attention_bias: bool = False, | |
attention_out_bias: bool = True, | |
): | |
super().__init__() | |
if kv_input_dim != dim: | |
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) | |
else: | |
self.kv_mapper = None | |
self.norm1 = RMSNorm(dim, 1e-06) | |
self.attn1 = Attention( | |
query_dim=dim, | |
heads=num_attention_heads, | |
dim_head=attention_head_dim, | |
dropout=dropout, | |
bias=attention_bias, | |
cross_attention_dim=cross_attention_dim, | |
out_bias=attention_out_bias, | |
) | |
self.norm2 = RMSNorm(dim, 1e-06) | |
self.attn2 = Attention( | |
query_dim=dim, | |
cross_attention_dim=cross_attention_dim, | |
heads=num_attention_heads, | |
dim_head=attention_head_dim, | |
dropout=dropout, | |
bias=attention_bias, | |
out_bias=attention_out_bias, | |
) | |
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): | |
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} | |
if self.kv_mapper is not None: | |
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) | |
norm_hidden_states = self.norm1(hidden_states) | |
attn_output = self.attn1( | |
norm_hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
**cross_attention_kwargs, | |
) | |
hidden_states = attn_output + hidden_states | |
norm_hidden_states = self.norm2(hidden_states) | |
attn_output = self.attn2( | |
norm_hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
**cross_attention_kwargs, | |
) | |
hidden_states = attn_output + hidden_states | |
return hidden_states | |
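# --- Usage sketch (added for illustration; toy sizes, not from the original diffusers source) --- |
# The kv_mapper projects the conditioning features to `dim` before both |
# attention layers attend over them; there is no feed-forward in this block. |
# |
#     block = SkipFFTransformerBlock( |
#         dim=64, |
#         num_attention_heads=4, |
#         attention_head_dim=16, |
#         kv_input_dim=32, |
#         kv_input_dim_proj_use_bias=False, |
#         cross_attention_dim=64, |
#     ) |
#     hidden = torch.randn(2, 16, 64) |
#     context = torch.randn(2, 8, 32)  # projected to dim=64 by kv_mapper |
#     out = block(hidden, context, cross_attention_kwargs=None) |
#     assert out.shape == hidden.shape |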
class FeedForward(nn.Module): | |
r""" | |
A feed-forward layer. | |
Parameters: | |
dim (`int`): The number of channels in the input. | |
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. | |
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. | |
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. | |
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. | |
""" | |
def __init__( | |
self, | |
dim: int, | |
dim_out: Optional[int] = None, | |
mult: int = 4, | |
dropout: float = 0.0, | |
activation_fn: str = "geglu", | |
final_dropout: bool = False, | |
inner_dim=None, | |
bias: bool = True, | |
): | |
super().__init__() | |
if inner_dim is None: | |
inner_dim = int(dim * mult) | |
dim_out = dim_out if dim_out is not None else dim | |
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear | |
if activation_fn == "gelu": | |
act_fn = GELU(dim, inner_dim, bias=bias) | |
if activation_fn == "gelu-approximate": | |
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) | |
elif activation_fn == "geglu": | |
act_fn = GEGLU(dim, inner_dim, bias=bias) | |
elif activation_fn == "geglu-approximate": | |
act_fn = ApproximateGELU(dim, inner_dim, bias=bias) | |
self.net = nn.ModuleList([]) | |
# project in | |
self.net.append(act_fn) | |
# project dropout | |
self.net.append(nn.Dropout(dropout)) | |
# project out | |
self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) | |
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout | |
if final_dropout: | |
self.net.append(nn.Dropout(dropout)) | |
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: | |
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) | |
for module in self.net: | |
if isinstance(module, compatible_cls): | |
hidden_states = module(hidden_states, scale) | |
else: | |
hidden_states = module(hidden_states) | |
return hidden_states | |
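# --- Usage sketch (added for illustration; toy sizes, not from the original diffusers source) --- |
# With the default mult=4, GEGLU projects dim -> 2 * (4 * dim), gates one half |
# with GELU of the other, and the final linear projects back to dim_out. |
# |
#     ff = FeedForward(dim=64, activation_fn="geglu", dropout=0.1) |
#     x = torch.randn(2, 16, 64) |
#     y = ff(x) |
#     assert y.shape == x.shape |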
from dataclasses import dataclass | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from ..configuration_utils import ConfigMixin, register_to_config | |
from ..utils import BaseOutput | |
from ..utils.accelerate_utils import apply_forward_hook | |
from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer | |
from .modeling_utils import ModelMixin | |
@dataclass | |
class VQEncoderOutput(BaseOutput): | |
""" | |
Output of VQModel encoding method. | |
Args: | |
latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): | |
The encoded output sample from the last layer of the model. | |
""" | |
latents: torch.FloatTensor | |
class VQModel(ModelMixin, ConfigMixin): | |
r""" | |
A VQ-VAE model for decoding latent representations. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented |
for all models (such as downloading or saving). | |
Parameters: | |
in_channels (int, *optional*, defaults to 3): Number of channels in the input image. | |
out_channels (int, *optional*, defaults to 3): Number of channels in the output. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): | |
Tuple of downsample block types. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): | |
Tuple of upsample block types. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): | |
Tuple of block output channels. | |
layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block. | |
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. | |
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. | |
sample_size (`int`, *optional*, defaults to `32`): Sample input size. | |
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. | |
norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers. | |
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. | |
scaling_factor (`float`, *optional*, defaults to `0.18215`): | |
The component-wise standard deviation of the trained latent space computed using the first batch of the | |
training set. This is used to scale the latent space to have unit variance when training the diffusion | |
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the | |
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 | |
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image | |
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. | |
norm_type (`str`, *optional*, defaults to `"group"`): | |
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. | |
""" | |
@register_to_config | |
def __init__( | |
self, | |
in_channels: int = 3, | |
out_channels: int = 3, | |
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), | |
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), | |
block_out_channels: Tuple[int, ...] = (64,), | |
layers_per_block: int = 1, | |
act_fn: str = "silu", | |
latent_channels: int = 3, | |
sample_size: int = 32, | |
num_vq_embeddings: int = 256, | |
norm_num_groups: int = 32, | |
vq_embed_dim: Optional[int] = None, | |
scaling_factor: float = 0.18215, | |
norm_type: str = "group", # group, spatial | |
mid_block_add_attention=True, | |
lookup_from_codebook=False, | |
force_upcast=False, | |
): | |
super().__init__() | |
# pass init params to Encoder | |
self.encoder = Encoder( | |
in_channels=in_channels, | |
out_channels=latent_channels, | |
down_block_types=down_block_types, | |
block_out_channels=block_out_channels, | |
layers_per_block=layers_per_block, | |
act_fn=act_fn, | |
norm_num_groups=norm_num_groups, | |
double_z=False, | |
mid_block_add_attention=mid_block_add_attention, | |
) | |
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels | |
self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) | |
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) | |
self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) | |
# pass init params to Decoder | |
self.decoder = Decoder( | |
in_channels=latent_channels, | |
out_channels=out_channels, | |
up_block_types=up_block_types, | |
block_out_channels=block_out_channels, | |
layers_per_block=layers_per_block, | |
act_fn=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_type=norm_type, | |
mid_block_add_attention=mid_block_add_attention, | |
) | |
@apply_forward_hook | |
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: | |
h = self.encoder(x) | |
h = self.quant_conv(h) | |
if not return_dict: | |
return (h,) | |
return VQEncoderOutput(latents=h) | |
@apply_forward_hook | |
def decode( | |
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None | |
) -> Union[DecoderOutput, torch.FloatTensor]: | |
# also go through quantization layer | |
if not force_not_quantize: | |
quant, _, _ = self.quantize(h) | |
elif self.config.lookup_from_codebook: | |
quant = self.quantize.get_codebook_entry(h, shape) | |
else: | |
quant = h | |
quant2 = self.post_quant_conv(quant) | |
dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) | |
if not return_dict: | |
return (dec,) | |
return DecoderOutput(sample=dec) | |
def forward( | |
self, sample: torch.FloatTensor, return_dict: bool = True | |
) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]: | |
r""" | |
The [`VQModel`] forward method. | |
Args: | |
sample (`torch.FloatTensor`): Input sample. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. | |
Returns: | |
[`~models.vq_model.VQEncoderOutput`] or `tuple`: | |
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` | |
is returned. | |
""" | |
h = self.encode(sample).latents | |
dec = self.decode(h).sample | |
if not return_dict: | |
return (dec,) | |
return DecoderOutput(sample=dec) | |
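# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: drive the VQ-VAE above end to end: image -> encode() -> codebook
# bottleneck -> decode() -> image. It relies only on the constructor defaults and the
# encode()/decode() signatures defined in this class; `_demo_vq_model_roundtrip` is a
# hypothetical helper name introduced just for this sketch.
def _demo_vq_model_roundtrip():
    model = VQModel()  # default config: 3 -> 3 channels, latent_channels=3, single down/up block
    x = torch.randn(2, 3, 32, 32)  # (batch, in_channels, height, width)
    with torch.no_grad():
        latents = model.encode(x).latents
        # encode() projects to `vq_embed_dim` channels, which defaults to `latent_channels`
        assert latents.shape[1] == model.config.latent_channels
        recon = model.decode(latents).sample
    # the decoder mirrors the encoder; with the default single-block config there is
    # no down/upsampling, so the reconstruction also keeps the input resolution
    assert recon.shape[1] == model.config.out_channels
    return recon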
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" PyTorch - Flax general utilities.""" | |
from pickle import UnpicklingError | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from flax.serialization import from_bytes | |
from flax.traverse_util import flatten_dict | |
from ..utils import logging | |
logger = logging.get_logger(__name__) | |
##################### | |
# Flax => PyTorch # | |
##################### | |
# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352 | |
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file): | |
try: | |
with open(model_file, "rb") as flax_state_f: | |
flax_state = from_bytes(None, flax_state_f.read()) | |
except UnpicklingError as e: | |
try: | |
with open(model_file) as f: | |
if f.read().startswith("version"): | |
raise OSError( | |
"You seem to have cloned a repository without having git-lfs installed. Please" | |
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the" | |
" folder you cloned." | |
) | |
else: | |
raise ValueError from e | |
except (UnicodeDecodeError, ValueError): | |
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") | |
return load_flax_weights_in_pytorch_model(pt_model, flax_state) | |
def load_flax_weights_in_pytorch_model(pt_model, flax_state): | |
"""Load flax checkpoints in a PyTorch model""" | |
try: | |
import torch # noqa: F401 | |
except ImportError: | |
logger.error( | |
"Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see" | |
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" | |
" instructions." | |
) | |
raise | |
# check if we have bf16 weights | |
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() | |
if any(is_type_bf16): | |
# convert all weights to fp32 if they are bf16 since torch.from_numpy cannot handle bf16 | |
# and bf16 is not fully supported in PT yet. | |
logger.warning( | |
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " | |
"before loading those in PyTorch model." | |
) | |
flax_state = jax.tree_util.tree_map( | |
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state | |
) | |
pt_model.base_model_prefix = "" | |
flax_state_dict = flatten_dict(flax_state, sep=".") | |
pt_model_dict = pt_model.state_dict() | |
# keep track of unexpected & missing keys | |
unexpected_keys = [] | |
missing_keys = set(pt_model_dict.keys()) | |
for flax_key_tuple, flax_tensor in flax_state_dict.items(): | |
flax_key_tuple_array = flax_key_tuple.split(".") | |
if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4: | |
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] | |
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) | |
elif flax_key_tuple_array[-1] == "kernel": | |
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] | |
flax_tensor = flax_tensor.T | |
elif flax_key_tuple_array[-1] == "scale": | |
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] | |
if "time_embedding" not in flax_key_tuple_array: | |
for i, flax_key_tuple_string in enumerate(flax_key_tuple_array): | |
flax_key_tuple_array[i] = ( | |
flax_key_tuple_string.replace("_0", ".0") | |
.replace("_1", ".1") | |
.replace("_2", ".2") | |
.replace("_3", ".3") | |
.replace("_4", ".4") | |
.replace("_5", ".5") | |
.replace("_6", ".6") | |
.replace("_7", ".7") | |
.replace("_8", ".8") | |
.replace("_9", ".9") | |
) | |
flax_key = ".".join(flax_key_tuple_array) | |
if flax_key in pt_model_dict: | |
if flax_tensor.shape != pt_model_dict[flax_key].shape: | |
raise ValueError( | |
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " | |
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." | |
) | |
else: | |
# add weight to pytorch dict | |
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor | |
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) | |
# remove from missing keys | |
missing_keys.remove(flax_key) | |
else: | |
# weight is not expected by PyTorch model | |
unexpected_keys.append(flax_key) | |
pt_model.load_state_dict(pt_model_dict) | |
# re-transform missing_keys to list | |
missing_keys = list(missing_keys) | |
if len(unexpected_keys) > 0: | |
logger.warning( | |
"Some weights of the Flax model were not used when initializing the PyTorch model" | |
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" | |
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" | |
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" | |
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" | |
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" | |
" FlaxBertForSequenceClassification model)." | |
) | |
if len(missing_keys) > 0: | |
logger.warning( | |
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" | |
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" | |
" use it for predictions and inference." | |
) | |
return pt_model | |
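# --- Hedged illustration (added for clarity; not part of the diffusers source). ---
# Intention: make the kernel-layout conversion above concrete. Flax stores 2D conv
# kernels as (H, W, in_channels, out_channels) while PyTorch expects
# (out_channels, in_channels, H, W), hence the transpose (3, 2, 0, 1) used in
# `load_flax_weights_in_pytorch_model`; dense "kernel" weights only need a plain
# transpose. `_demo_flax_to_pt_layouts` is a hypothetical helper name for this sketch.
def _demo_flax_to_pt_layouts():
    flax_conv_kernel = np.zeros((3, 3, 16, 32))      # (H, W, C_in, C_out)
    pt_conv_weight = np.transpose(flax_conv_kernel, (3, 2, 0, 1))
    assert pt_conv_weight.shape == (32, 16, 3, 3)    # (C_out, C_in, H, W)
    flax_dense_kernel = np.zeros((128, 256))         # (in_features, out_features)
    assert flax_dense_kernel.T.shape == (256, 128)   # PyTorch linear weight: (out, in)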
import functools | |
import math | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): | |
"""Multi-head dot product attention with a limited number of queries.""" | |
num_kv, num_heads, k_features = key.shape[-3:] | |
v_features = value.shape[-1] | |
key_chunk_size = min(key_chunk_size, num_kv) | |
query = query / jnp.sqrt(k_features) | |
@functools.partial(jax.checkpoint, prevent_cse=False) | |
def summarize_chunk(query, key, value): | |
attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) | |
max_score = jnp.max(attn_weights, axis=-1, keepdims=True) | |
max_score = jax.lax.stop_gradient(max_score) | |
exp_weights = jnp.exp(attn_weights - max_score) | |
exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) | |
max_score = jnp.einsum("...qhk->...qh", max_score) | |
return (exp_values, exp_weights.sum(axis=-1), max_score) | |
def chunk_scanner(chunk_idx): | |
# julienne key array | |
key_chunk = jax.lax.dynamic_slice( | |
operand=key, | |
start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] | |
slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] | |
) | |
# julienne value array | |
value_chunk = jax.lax.dynamic_slice( | |
operand=value, | |
start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] | |
slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] | |
) | |
return summarize_chunk(query, key_chunk, value_chunk) | |
chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) | |
global_max = jnp.max(chunk_max, axis=0, keepdims=True) | |
max_diffs = jnp.exp(chunk_max - global_max) | |
chunk_values *= jnp.expand_dims(max_diffs, axis=-1) | |
chunk_weights *= max_diffs | |
all_values = chunk_values.sum(axis=0) | |
all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) | |
return all_values / all_weights | |
def jax_memory_efficient_attention( | |
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 | |
): | |
r""" | |
Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2 | |
https://github.com/AminRezaei0x443/memory-efficient-attention | |
Args: | |
query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) | |
key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) | |
value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) | |
precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): | |
numerical precision for computation | |
query_chunk_size (`int`, *optional*, defaults to 1024): | |
chunk size used to split the query array; the value must divide query_length evenly (no remainder) | |
key_chunk_size (`int`, *optional*, defaults to 4096): | |
chunk size used to split the key and value arrays; the value must divide key_value_length evenly (no remainder) | |
Returns: | |
(`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) | |
""" | |
num_q, num_heads, q_features = query.shape[-3:] | |
def chunk_scanner(chunk_idx, _): | |
# julienne query array | |
query_chunk = jax.lax.dynamic_slice( | |
operand=query, | |
start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] | |
slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] | |
) | |
return ( | |
chunk_idx + query_chunk_size, # unused ignore it | |
_query_chunk_attention( | |
query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size | |
), | |
) | |
_, res = jax.lax.scan( | |
f=chunk_scanner, | |
init=0, | |
xs=None, | |
length=math.ceil(num_q / query_chunk_size), # start counter # stop counter | |
) | |
return jnp.concatenate(res, axis=-3) # fuse the chunked result back | |
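# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: exercise the chunked attention above on toy shapes. Inputs follow the
# documented layout (batch..., length, heads, depth_per_head), and the chunk sizes are
# chosen to divide the sequence lengths evenly. `_demo_memory_efficient_attention` is a
# hypothetical helper name introduced just for this sketch.
def _demo_memory_efficient_attention():
    rng = jax.random.PRNGKey(0)
    q_key, k_key, v_key = jax.random.split(rng, 3)
    query = jax.random.normal(q_key, (2, 256, 8, 64))  # (batch, q_len, heads, d_head)
    key = jax.random.normal(k_key, (2, 512, 8, 64))    # (batch, kv_len, heads, d_head)
    value = jax.random.normal(v_key, (2, 512, 8, 64))
    out = jax_memory_efficient_attention(query, key, value, query_chunk_size=64, key_chunk_size=128)
    # output keeps the query layout: (batch, q_len, heads, value_depth_per_head)
    assert out.shape == query.shape
    return out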
class FlaxAttention(nn.Module): | |
r""" | |
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 | |
Parameters: | |
query_dim (:obj:`int`): | |
Input hidden states dimension | |
heads (:obj:`int`, *optional*, defaults to 8): | |
Number of heads | |
dim_head (:obj:`int`, *optional*, defaults to 64): | |
Hidden states dimension inside each head | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): | |
enable memory efficient attention https://arxiv.org/abs/2112.05682 | |
split_head_dim (`bool`, *optional*, defaults to `False`): | |
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, | |
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
query_dim: int | |
heads: int = 8 | |
dim_head: int = 64 | |
dropout: float = 0.0 | |
use_memory_efficient_attention: bool = False | |
split_head_dim: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
inner_dim = self.dim_head * self.heads | |
self.scale = self.dim_head**-0.5 | |
# Weights were exported with old names {to_q, to_k, to_v, to_out} | |
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") | |
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") | |
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") | |
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") | |
self.dropout_layer = nn.Dropout(rate=self.dropout) | |
def reshape_heads_to_batch_dim(self, tensor): | |
batch_size, seq_len, dim = tensor.shape | |
head_size = self.heads | |
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) | |
tensor = jnp.transpose(tensor, (0, 2, 1, 3)) | |
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) | |
return tensor | |
def reshape_batch_dim_to_heads(self, tensor): | |
batch_size, seq_len, dim = tensor.shape | |
head_size = self.heads | |
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) | |
tensor = jnp.transpose(tensor, (0, 2, 1, 3)) | |
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) | |
return tensor | |
def __call__(self, hidden_states, context=None, deterministic=True): | |
context = hidden_states if context is None else context | |
query_proj = self.query(hidden_states) | |
key_proj = self.key(context) | |
value_proj = self.value(context) | |
if self.split_head_dim: | |
b = hidden_states.shape[0] | |
query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head)) | |
key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head)) | |
value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head)) | |
else: | |
query_states = self.reshape_heads_to_batch_dim(query_proj) | |
key_states = self.reshape_heads_to_batch_dim(key_proj) | |
value_states = self.reshape_heads_to_batch_dim(value_proj) | |
if self.use_memory_efficient_attention: | |
query_states = query_states.transpose(1, 0, 2) | |
key_states = key_states.transpose(1, 0, 2) | |
value_states = value_states.transpose(1, 0, 2) | |
# this if statement creates a chunk size for each layer of the UNet | |
# the chunk size is equal to the query_length dimension of the deepest layer of the UNet | |
flatten_latent_dim = query_states.shape[-3] | |
if flatten_latent_dim % 64 == 0: | |
query_chunk_size = int(flatten_latent_dim / 64) | |
elif flatten_latent_dim % 16 == 0: | |
query_chunk_size = int(flatten_latent_dim / 16) | |
elif flatten_latent_dim % 4 == 0: | |
query_chunk_size = int(flatten_latent_dim / 4) | |
else: | |
query_chunk_size = int(flatten_latent_dim) | |
hidden_states = jax_memory_efficient_attention( | |
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 | |
) | |
hidden_states = hidden_states.transpose(1, 0, 2) | |
else: | |
# compute attentions | |
if self.split_head_dim: | |
attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states) | |
else: | |
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) | |
attention_scores = attention_scores * self.scale | |
attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2) | |
# attend to values | |
if self.split_head_dim: | |
hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states) | |
b = hidden_states.shape[0] | |
hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head)) | |
else: | |
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) | |
hidden_states = self.reshape_batch_dim_to_heads(hidden_states) | |
hidden_states = self.proj_attn(hidden_states) | |
return self.dropout_layer(hidden_states, deterministic=deterministic) | |
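# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: show the usual Flax init/apply cycle for the attention module above on a
# (batch, sequence, query_dim) tensor. `_demo_flax_attention` is a hypothetical helper
# name introduced just for this sketch.
def _demo_flax_attention():
    attn = FlaxAttention(query_dim=320, heads=8, dim_head=40)
    hidden_states = jnp.ones((1, 64, 320))            # (batch, seq_len, query_dim)
    params = attn.init(jax.random.PRNGKey(0), hidden_states)
    # no `context` -> self-attention; dropout is skipped because deterministic=True
    out = attn.apply(params, hidden_states)
    assert out.shape == hidden_states.shape           # `to_out_0` projects back to query_dim
    return out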
class FlaxBasicTransformerBlock(nn.Module): | |
r""" | |
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: | |
https://arxiv.org/abs/1706.03762 | |
Parameters: | |
dim (:obj:`int`): | |
Inner hidden states dimension | |
n_heads (:obj:`int`): | |
Number of heads | |
d_head (:obj:`int`): | |
Hidden states dimension inside each head | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
only_cross_attention (`bool`, defaults to `False`): | |
Whether to only apply cross attention. | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): | |
enable memory efficient attention https://arxiv.org/abs/2112.05682 | |
split_head_dim (`bool`, *optional*, defaults to `False`): | |
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, | |
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. | |
""" | |
dim: int | |
n_heads: int | |
d_head: int | |
dropout: float = 0.0 | |
only_cross_attention: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
use_memory_efficient_attention: bool = False | |
split_head_dim: bool = False | |
def setup(self): | |
# self attention (or cross_attention if only_cross_attention is True) | |
self.attn1 = FlaxAttention( | |
self.dim, | |
self.n_heads, | |
self.d_head, | |
self.dropout, | |
self.use_memory_efficient_attention, | |
self.split_head_dim, | |
dtype=self.dtype, | |
) | |
# cross attention | |
self.attn2 = FlaxAttention( | |
self.dim, | |
self.n_heads, | |
self.d_head, | |
self.dropout, | |
self.use_memory_efficient_attention, | |
self.split_head_dim, | |
dtype=self.dtype, | |
) | |
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) | |
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) | |
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) | |
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) | |
self.dropout_layer = nn.Dropout(rate=self.dropout) | |
def __call__(self, hidden_states, context, deterministic=True): | |
# self attention | |
residual = hidden_states | |
if self.only_cross_attention: | |
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) | |
else: | |
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) | |
hidden_states = hidden_states + residual | |
# cross attention | |
residual = hidden_states | |
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) | |
hidden_states = hidden_states + residual | |
# feed forward | |
residual = hidden_states | |
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) | |
hidden_states = hidden_states + residual | |
return self.dropout_layer(hidden_states, deterministic=deterministic) | |
class FlaxTransformer2DModel(nn.Module): | |
r""" | |
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: | |
https://arxiv.org/pdf/1506.02025.pdf | |
Parameters: | |
in_channels (:obj:`int`): | |
Input number of channels | |
n_heads (:obj:`int`): | |
Number of heads | |
d_head (:obj:`int`): | |
Hidden states dimension inside each head | |
depth (:obj:`int`, *optional*, defaults to 1): | |
Number of transformer blocks | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
use_linear_projection (`bool`, defaults to `False`): Whether to use a linear (`nn.Dense`) layer instead of a 1x1 convolution for the input/output projection. | |
only_cross_attention (`bool`, defaults to `False`): Whether the transformer blocks should only apply cross-attention to the provided `context`. | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): | |
enable memory efficient attention https://arxiv.org/abs/2112.05682 | |
split_head_dim (`bool`, *optional*, defaults to `False`): | |
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, | |
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. | |
""" | |
in_channels: int | |
n_heads: int | |
d_head: int | |
depth: int = 1 | |
dropout: float = 0.0 | |
use_linear_projection: bool = False | |
only_cross_attention: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
use_memory_efficient_attention: bool = False | |
split_head_dim: bool = False | |
def setup(self): | |
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) | |
inner_dim = self.n_heads * self.d_head | |
if self.use_linear_projection: | |
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype) | |
else: | |
self.proj_in = nn.Conv( | |
inner_dim, | |
kernel_size=(1, 1), | |
strides=(1, 1), | |
padding="VALID", | |
dtype=self.dtype, | |
) | |
self.transformer_blocks = [ | |
FlaxBasicTransformerBlock( | |
inner_dim, | |
self.n_heads, | |
self.d_head, | |
dropout=self.dropout, | |
only_cross_attention=self.only_cross_attention, | |
dtype=self.dtype, | |
use_memory_efficient_attention=self.use_memory_efficient_attention, | |
split_head_dim=self.split_head_dim, | |
) | |
for _ in range(self.depth) | |
] | |
if self.use_linear_projection: | |
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype) | |
else: | |
self.proj_out = nn.Conv( | |
inner_dim, | |
kernel_size=(1, 1), | |
strides=(1, 1), | |
padding="VALID", | |
dtype=self.dtype, | |
) | |
self.dropout_layer = nn.Dropout(rate=self.dropout) | |
def __call__(self, hidden_states, context, deterministic=True): | |
batch, height, width, channels = hidden_states.shape | |
residual = hidden_states | |
hidden_states = self.norm(hidden_states) | |
if self.use_linear_projection: | |
hidden_states = hidden_states.reshape(batch, height * width, channels) | |
hidden_states = self.proj_in(hidden_states) | |
else: | |
hidden_states = self.proj_in(hidden_states) | |
hidden_states = hidden_states.reshape(batch, height * width, channels) | |
for transformer_block in self.transformer_blocks: | |
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) | |
if self.use_linear_projection: | |
hidden_states = self.proj_out(hidden_states) | |
hidden_states = hidden_states.reshape(batch, height, width, channels) | |
else: | |
hidden_states = hidden_states.reshape(batch, height, width, channels) | |
hidden_states = self.proj_out(hidden_states) | |
hidden_states = hidden_states + residual | |
return self.dropout_layer(hidden_states, deterministic=deterministic) | |
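# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: run the spatial transformer above on a small NHWC feature map with a
# text-like context tensor. In this sketch `in_channels` equals `n_heads * d_head`,
# matching how the UNet configures this block so the residual connection lines up.
# `_demo_flax_transformer_2d` is a hypothetical helper name introduced for this sketch.
def _demo_flax_transformer_2d():
    model = FlaxTransformer2DModel(in_channels=64, n_heads=2, d_head=32, depth=1)
    sample = jnp.ones((1, 8, 8, 64))                   # (batch, height, width, channels) -- NHWC
    context = jnp.ones((1, 77, 768))                   # (batch, tokens, context_dim)
    params = model.init(jax.random.PRNGKey(0), sample, context)
    out = model.apply(params, sample, context)
    assert out.shape == sample.shape                   # the block is shape-preserving
    return out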
class FlaxFeedForward(nn.Module): | |
r""" | |
Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's | |
[`FeedForward`] class, with the following simplifications: | |
- The activation function is currently hardcoded to a gated linear unit from: | |
https://arxiv.org/abs/2002.05202 | |
- `dim_out` is equal to `dim`. | |
- The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`]. | |
Parameters: | |
dim (:obj:`int`): | |
Inner hidden states dimension | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
dim: int | |
dropout: float = 0.0 | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
# The second linear layer needs to be called | |
# net_2 for now to match the index of the Sequential layer | |
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) | |
self.net_2 = nn.Dense(self.dim, dtype=self.dtype) | |
def __call__(self, hidden_states, deterministic=True): | |
hidden_states = self.net_0(hidden_states, deterministic=deterministic) | |
hidden_states = self.net_2(hidden_states) | |
return hidden_states | |
class FlaxGEGLU(nn.Module): | |
r""" | |
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from | |
https://arxiv.org/abs/2002.05202. | |
Parameters: | |
dim (:obj:`int`): | |
Input hidden states dimension | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
dim: int | |
dropout: float = 0.0 | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
inner_dim = self.dim * 4 | |
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) | |
self.dropout_layer = nn.Dropout(rate=self.dropout) | |
def __call__(self, hidden_states, deterministic=True): | |
hidden_states = self.proj(hidden_states) | |
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) | |
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic) | |
from ..utils import deprecate | |
from .unets.unet_2d import UNet2DModel, UNet2DOutput | |
class UNet2DOutput(UNet2DOutput): | |
deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead." | |
deprecate("UNet2DOutput", "0.29", deprecation_message) | |
class UNet2DModel(UNet2DModel): | |
deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead." | |
deprecate("UNet2DModel", "0.29", deprecation_message) | |
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# `TemporalConvLayer` Copyright 2024 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from functools import partial | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..utils import USE_PEFT_BACKEND | |
from .activations import get_activation | |
from .attention_processor import SpatialNorm | |
from .downsampling import ( # noqa | |
Downsample1D, | |
Downsample2D, | |
FirDownsample2D, | |
KDownsample2D, | |
downsample_2d, | |
) | |
from .lora import LoRACompatibleConv, LoRACompatibleLinear | |
from .normalization import AdaGroupNorm | |
from .upsampling import ( # noqa | |
FirUpsample2D, | |
KUpsample2D, | |
Upsample1D, | |
Upsample2D, | |
upfirdn2d_native, | |
upsample_2d, | |
) | |
class ResnetBlockCondNorm2D(nn.Module): | |
r""" | |
A ResNet block that uses a normalization layer which incorporates conditioning information. | |
Parameters: | |
in_channels (`int`): The number of channels in the input. | |
out_channels (`int`, *optional*, default to be `None`): | |
The number of output channels for the first conv2d layer. If None, same as `in_channels`. | |
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. | |
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. | |
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. | |
groups_out (`int`, *optional*, default to None): | |
The number of groups to use for the second normalization layer. if set to None, same as `groups`. | |
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. | |
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. | |
time_embedding_norm (`str`, *optional*, default to `"ada_group"` ): | |
The normalization layer to use for the time embedding `temb`. Currently only supports "ada_group" or "spatial". | |
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see | |
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. | |
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. | |
use_in_shortcut (`bool`, *optional*, default to `True`): | |
If `True`, add a 1x1 nn.conv2d layer for skip-connection. | |
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. | |
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. | |
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the | |
`conv_shortcut` output. | |
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. | |
If None, same as `out_channels`. | |
""" | |
def __init__( | |
self, | |
*, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
conv_shortcut: bool = False, | |
dropout: float = 0.0, | |
temb_channels: int = 512, | |
groups: int = 32, | |
groups_out: Optional[int] = None, | |
eps: float = 1e-6, | |
non_linearity: str = "swish", | |
time_embedding_norm: str = "ada_group", # ada_group, spatial | |
output_scale_factor: float = 1.0, | |
use_in_shortcut: Optional[bool] = None, | |
up: bool = False, | |
down: bool = False, | |
conv_shortcut_bias: bool = True, | |
conv_2d_out_channels: Optional[int] = None, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
out_channels = in_channels if out_channels is None else out_channels | |
self.out_channels = out_channels | |
self.use_conv_shortcut = conv_shortcut | |
self.up = up | |
self.down = down | |
self.output_scale_factor = output_scale_factor | |
self.time_embedding_norm = time_embedding_norm | |
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv | |
if groups_out is None: | |
groups_out = groups | |
if self.time_embedding_norm == "ada_group": # ada_group | |
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) | |
elif self.time_embedding_norm == "spatial": | |
self.norm1 = SpatialNorm(in_channels, temb_channels) | |
else: | |
raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") | |
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) | |
if self.time_embedding_norm == "ada_group": # ada_group | |
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) | |
elif self.time_embedding_norm == "spatial": # spatial | |
self.norm2 = SpatialNorm(out_channels, temb_channels) | |
else: | |
raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}") | |
self.dropout = torch.nn.Dropout(dropout) | |
conv_2d_out_channels = conv_2d_out_channels or out_channels | |
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) | |
self.nonlinearity = get_activation(non_linearity) | |
self.upsample = self.downsample = None | |
if self.up: | |
self.upsample = Upsample2D(in_channels, use_conv=False) | |
elif self.down: | |
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") | |
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut | |
self.conv_shortcut = None | |
if self.use_in_shortcut: | |
self.conv_shortcut = conv_cls( | |
in_channels, | |
conv_2d_out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
bias=conv_shortcut_bias, | |
) | |
def forward( | |
self, | |
input_tensor: torch.FloatTensor, | |
temb: torch.FloatTensor, | |
scale: float = 1.0, | |
) -> torch.FloatTensor: | |
hidden_states = input_tensor | |
hidden_states = self.norm1(hidden_states, temb) | |
hidden_states = self.nonlinearity(hidden_states) | |
if self.upsample is not None: | |
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 | |
if hidden_states.shape[0] >= 64: | |
input_tensor = input_tensor.contiguous() | |
hidden_states = hidden_states.contiguous() | |
input_tensor = self.upsample(input_tensor, scale=scale) | |
hidden_states = self.upsample(hidden_states, scale=scale) | |
elif self.downsample is not None: | |
input_tensor = self.downsample(input_tensor, scale=scale) | |
hidden_states = self.downsample(hidden_states, scale=scale) | |
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) | |
hidden_states = self.norm2(hidden_states, temb) | |
hidden_states = self.nonlinearity(hidden_states) | |
hidden_states = self.dropout(hidden_states) | |
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) | |
if self.conv_shortcut is not None: | |
input_tensor = ( | |
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) | |
) | |
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor | |
return output_tensor | |
class ResnetBlock2D(nn.Module): | |
r""" | |
A Resnet block. | |
Parameters: | |
in_channels (`int`): The number of channels in the input. | |
out_channels (`int`, *optional*, default to be `None`): | |
The number of output channels for the first conv2d layer. If None, same as `in_channels`. | |
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. | |
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. | |
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. | |
groups_out (`int`, *optional*, default to None): | |
The number of groups to use for the second normalization layer. if set to None, same as `groups`. | |
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. | |
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. | |
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. | |
By default, applies timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" | |
for a stronger conditioning with scale and shift. | |
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see | |
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. | |
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. | |
use_in_shortcut (`bool`, *optional*, default to `True`): | |
If `True`, add a 1x1 nn.conv2d layer for skip-connection. | |
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. | |
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. | |
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the | |
`conv_shortcut` output. | |
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. | |
If None, same as `out_channels`. | |
""" | |
def __init__( | |
self, | |
*, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
conv_shortcut: bool = False, | |
dropout: float = 0.0, | |
temb_channels: int = 512, | |
groups: int = 32, | |
groups_out: Optional[int] = None, | |
pre_norm: bool = True, | |
eps: float = 1e-6, | |
non_linearity: str = "swish", | |
skip_time_act: bool = False, | |
time_embedding_norm: str = "default", # default, scale_shift, | |
kernel: Optional[torch.FloatTensor] = None, | |
output_scale_factor: float = 1.0, | |
use_in_shortcut: Optional[bool] = None, | |
up: bool = False, | |
down: bool = False, | |
conv_shortcut_bias: bool = True, | |
conv_2d_out_channels: Optional[int] = None, | |
): | |
super().__init__() | |
if time_embedding_norm == "ada_group": | |
raise ValueError( | |
"This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead", | |
) | |
if time_embedding_norm == "spatial": | |
raise ValueError( | |
"This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead", | |
) | |
self.pre_norm = True | |
self.in_channels = in_channels | |
out_channels = in_channels if out_channels is None else out_channels | |
self.out_channels = out_channels | |
self.use_conv_shortcut = conv_shortcut | |
self.up = up | |
self.down = down | |
self.output_scale_factor = output_scale_factor | |
self.time_embedding_norm = time_embedding_norm | |
self.skip_time_act = skip_time_act | |
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear | |
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv | |
if groups_out is None: | |
groups_out = groups | |
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) | |
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) | |
if temb_channels is not None: | |
if self.time_embedding_norm == "default": | |
self.time_emb_proj = linear_cls(temb_channels, out_channels) | |
elif self.time_embedding_norm == "scale_shift": | |
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) | |
else: | |
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") | |
else: | |
self.time_emb_proj = None | |
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) | |
self.dropout = torch.nn.Dropout(dropout) | |
conv_2d_out_channels = conv_2d_out_channels or out_channels | |
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) | |
self.nonlinearity = get_activation(non_linearity) | |
self.upsample = self.downsample = None | |
if self.up: | |
if kernel == "fir": | |
fir_kernel = (1, 3, 3, 1) | |
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) | |
elif kernel == "sde_vp": | |
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") | |
else: | |
self.upsample = Upsample2D(in_channels, use_conv=False) | |
elif self.down: | |
if kernel == "fir": | |
fir_kernel = (1, 3, 3, 1) | |
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) | |
elif kernel == "sde_vp": | |
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) | |
else: | |
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") | |
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut | |
self.conv_shortcut = None | |
if self.use_in_shortcut: | |
self.conv_shortcut = conv_cls( | |
in_channels, | |
conv_2d_out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
bias=conv_shortcut_bias, | |
) | |
def forward( | |
self, | |
input_tensor: torch.FloatTensor, | |
temb: torch.FloatTensor, | |
scale: float = 1.0, | |
) -> torch.FloatTensor: | |
hidden_states = input_tensor | |
hidden_states = self.norm1(hidden_states) | |
hidden_states = self.nonlinearity(hidden_states) | |
if self.upsample is not None: | |
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 | |
if hidden_states.shape[0] >= 64: | |
input_tensor = input_tensor.contiguous() | |
hidden_states = hidden_states.contiguous() | |
input_tensor = ( | |
self.upsample(input_tensor, scale=scale) | |
if isinstance(self.upsample, Upsample2D) | |
else self.upsample(input_tensor) | |
) | |
hidden_states = ( | |
self.upsample(hidden_states, scale=scale) | |
if isinstance(self.upsample, Upsample2D) | |
else self.upsample(hidden_states) | |
) | |
elif self.downsample is not None: | |
input_tensor = ( | |
self.downsample(input_tensor, scale=scale) | |
if isinstance(self.downsample, Downsample2D) | |
else self.downsample(input_tensor) | |
) | |
hidden_states = ( | |
self.downsample(hidden_states, scale=scale) | |
if isinstance(self.downsample, Downsample2D) | |
else self.downsample(hidden_states) | |
) | |
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) | |
if self.time_emb_proj is not None: | |
if not self.skip_time_act: | |
temb = self.nonlinearity(temb) | |
temb = ( | |
self.time_emb_proj(temb, scale)[:, :, None, None] | |
if not USE_PEFT_BACKEND | |
else self.time_emb_proj(temb)[:, :, None, None] | |
) | |
if self.time_embedding_norm == "default": | |
if temb is not None: | |
hidden_states = hidden_states + temb | |
hidden_states = self.norm2(hidden_states) | |
elif self.time_embedding_norm == "scale_shift": | |
if temb is None: | |
raise ValueError( | |
f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" | |
) | |
time_scale, time_shift = torch.chunk(temb, 2, dim=1) | |
hidden_states = self.norm2(hidden_states) | |
hidden_states = hidden_states * (1 + time_scale) + time_shift | |
else: | |
hidden_states = self.norm2(hidden_states) | |
hidden_states = self.nonlinearity(hidden_states) | |
hidden_states = self.dropout(hidden_states) | |
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) | |
if self.conv_shortcut is not None: | |
input_tensor = ( | |
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) | |
) | |
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor | |
return output_tensor | |
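# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: run the ResNet block above on a feature map together with a timestep
# embedding. Channel counts are multiples of `groups` (32) so the GroupNorm layers are
# valid. `_demo_resnet_block_2d` is a hypothetical helper name introduced for this sketch.
def _demo_resnet_block_2d():
    block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
    sample = torch.randn(2, 64, 32, 32)                # (batch, in_channels, height, width)
    temb = torch.randn(2, 512)                         # (batch, temb_channels)
    with torch.no_grad():
        out = block(sample, temb)
    # spatial size is preserved; channels follow `out_channels` via the 1x1 shortcut
    assert out.shape == (2, 128, 32, 32)
    return out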
# unet_rl.py | |
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor: | |
if len(tensor.shape) == 2: | |
return tensor[:, :, None] | |
if len(tensor.shape) == 3: | |
return tensor[:, :, None, :] | |
elif len(tensor.shape) == 4: | |
return tensor[:, :, 0, :] | |
else: | |
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") | |
class Conv1dBlock(nn.Module): | |
""" | |
Conv1d --> GroupNorm --> Mish | |
Parameters: | |
inp_channels (`int`): Number of input channels. | |
out_channels (`int`): Number of output channels. | |
kernel_size (`int` or `tuple`): Size of the convolving kernel. | |
n_groups (`int`, default `8`): Number of groups to separate the channels into. | |
activation (`str`, defaults to `mish`): Name of the activation function. | |
""" | |
def __init__( | |
self, | |
inp_channels: int, | |
out_channels: int, | |
kernel_size: Union[int, Tuple[int, int]], | |
n_groups: int = 8, | |
activation: str = "mish", | |
): | |
super().__init__() | |
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) | |
self.group_norm = nn.GroupNorm(n_groups, out_channels) | |
self.mish = get_activation(activation) | |
def forward(self, inputs: torch.Tensor) -> torch.Tensor: | |
intermediate_repr = self.conv1d(inputs) | |
intermediate_repr = rearrange_dims(intermediate_repr) | |
intermediate_repr = self.group_norm(intermediate_repr) | |
intermediate_repr = rearrange_dims(intermediate_repr) | |
output = self.mish(intermediate_repr) | |
return output | |
# unet_rl.py | |
class ResidualTemporalBlock1D(nn.Module): | |
""" | |
Residual 1D block with temporal convolutions. | |
Parameters: | |
inp_channels (`int`): Number of input channels. | |
out_channels (`int`): Number of output channels. | |
embed_dim (`int`): Embedding dimension. | |
kernel_size (`int` or `tuple`): Size of the convolving kernel. | |
activation (`str`, defaults to `"mish"`): Name of the activation function to use. | |
""" | |
def __init__( | |
self, | |
inp_channels: int, | |
out_channels: int, | |
embed_dim: int, | |
kernel_size: Union[int, Tuple[int, int]] = 5, | |
activation: str = "mish", | |
): | |
super().__init__() | |
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) | |
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) | |
self.time_emb_act = get_activation(activation) | |
self.time_emb = nn.Linear(embed_dim, out_channels) | |
self.residual_conv = ( | |
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() | |
) | |
def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: | |
""" | |
Args: | |
inputs : [ batch_size x inp_channels x horizon ] | |
t : [ batch_size x embed_dim ] | |
returns: | |
out : [ batch_size x out_channels x horizon ] | |
""" | |
t = self.time_emb_act(t) | |
t = self.time_emb(t) | |
out = self.conv_in(inputs) + rearrange_dims(t) | |
out = self.conv_out(out) | |
return out + self.residual_conv(inputs) | |
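# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: run the residual temporal block on a toy trajectory tensor using the
# (batch, channels, horizon) layout documented in `forward`.
# `_demo_residual_temporal_block` is a hypothetical helper name introduced for this sketch.
def _demo_residual_temporal_block():
    block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
    inputs = torch.randn(2, 14, 16)                    # (batch, inp_channels, horizon)
    t = torch.randn(2, 128)                            # (batch, embed_dim)
    with torch.no_grad():
        out = block(inputs, t)
    # channels change to `out_channels` via the 1x1 residual conv; the horizon is preserved
    assert out.shape == (2, 32, 16)
    return out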
class TemporalConvLayer(nn.Module): | |
""" | |
Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from: | |
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 | |
Parameters: | |
in_dim (`int`): Number of input channels. | |
out_dim (`int`): Number of output channels. | |
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. | |
""" | |
def __init__( | |
self, | |
in_dim: int, | |
out_dim: Optional[int] = None, | |
dropout: float = 0.0, | |
norm_num_groups: int = 32, | |
): | |
super().__init__() | |
out_dim = out_dim or in_dim | |
self.in_dim = in_dim | |
self.out_dim = out_dim | |
# conv layers | |
self.conv1 = nn.Sequential( | |
nn.GroupNorm(norm_num_groups, in_dim), | |
nn.SiLU(), | |
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)), | |
) | |
self.conv2 = nn.Sequential( | |
nn.GroupNorm(norm_num_groups, out_dim), | |
nn.SiLU(), | |
nn.Dropout(dropout), | |
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), | |
) | |
self.conv3 = nn.Sequential( | |
nn.GroupNorm(norm_num_groups, out_dim), | |
nn.SiLU(), | |
nn.Dropout(dropout), | |
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), | |
) | |
self.conv4 = nn.Sequential( | |
nn.GroupNorm(norm_num_groups, out_dim), | |
nn.SiLU(), | |
nn.Dropout(dropout), | |
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), | |
) | |
# zero out the last layer params, so the conv block is an identity at initialization | |
nn.init.zeros_(self.conv4[-1].weight) | |
nn.init.zeros_(self.conv4[-1].bias) | |
def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor: | |
hidden_states = ( | |
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) | |
) | |
identity = hidden_states | |
hidden_states = self.conv1(hidden_states) | |
hidden_states = self.conv2(hidden_states) | |
hidden_states = self.conv3(hidden_states) | |
hidden_states = self.conv4(hidden_states) | |
hidden_states = identity + hidden_states | |
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( | |
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] | |
) | |
return hidden_states | |
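# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: feed the temporal conv layer a batch of video frames flattened to
# (batch * num_frames, channels, height, width), as the video UNet does, and check that
# the zero-initialized `conv4` makes the block an identity at initialization.
# `_demo_temporal_conv_layer` is a hypothetical helper name introduced for this sketch.
def _demo_temporal_conv_layer():
    layer = TemporalConvLayer(in_dim=64, out_dim=64)
    frames = torch.randn(2 * 8, 64, 16, 16)            # batch=2 videos of num_frames=8
    with torch.no_grad():
        out = layer(frames, num_frames=8)
    assert out.shape == frames.shape
    # conv4 weights/biases are zeroed above, so only the identity branch contributes at init
    assert torch.allclose(out, frames)
    return out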
class TemporalResnetBlock(nn.Module): | |
r""" | |
A temporal ResNet block. | |
Parameters: | |
in_channels (`int`): The number of channels in the input. | |
out_channels (`int`, *optional*, default to be `None`): | |
The number of output channels for the first conv3d layer. If None, same as `in_channels`. | |
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. | |
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. | |
""" | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
temb_channels: int = 512, | |
eps: float = 1e-6, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
out_channels = in_channels if out_channels is None else out_channels | |
self.out_channels = out_channels | |
kernel_size = (3, 1, 1) | |
padding = [k // 2 for k in kernel_size] | |
self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True) | |
self.conv1 = nn.Conv3d( | |
in_channels, | |
out_channels, | |
kernel_size=kernel_size, | |
stride=1, | |
padding=padding, | |
) | |
if temb_channels is not None: | |
self.time_emb_proj = nn.Linear(temb_channels, out_channels) | |
else: | |
self.time_emb_proj = None | |
self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True) | |
self.dropout = torch.nn.Dropout(0.0) | |
self.conv2 = nn.Conv3d( | |
out_channels, | |
out_channels, | |
kernel_size=kernel_size, | |
stride=1, | |
padding=padding, | |
) | |
self.nonlinearity = get_activation("silu") | |
self.use_in_shortcut = self.in_channels != out_channels | |
self.conv_shortcut = None | |
if self.use_in_shortcut: | |
self.conv_shortcut = nn.Conv3d( | |
in_channels, | |
out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
) | |
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_states = input_tensor | |
hidden_states = self.norm1(hidden_states) | |
hidden_states = self.nonlinearity(hidden_states) | |
hidden_states = self.conv1(hidden_states) | |
if self.time_emb_proj is not None: | |
temb = self.nonlinearity(temb) | |
temb = self.time_emb_proj(temb)[:, :, :, None, None] | |
temb = temb.permute(0, 2, 1, 3, 4) | |
hidden_states = hidden_states + temb | |
hidden_states = self.norm2(hidden_states) | |
hidden_states = self.nonlinearity(hidden_states) | |
hidden_states = self.dropout(hidden_states) | |
hidden_states = self.conv2(hidden_states) | |
if self.conv_shortcut is not None: | |
input_tensor = self.conv_shortcut(input_tensor) | |
output_tensor = input_tensor + hidden_states | |
return output_tensor | |
# VideoResBlock | |
class SpatioTemporalResBlock(nn.Module): | |
r""" | |
A SpatioTemporal Resnet block. | |
Parameters: | |
in_channels (`int`): The number of channels in the input. | |
out_channels (`int`, *optional*, default to be `None`): | |
The number of output channels for the first conv2d layer. If None, same as `in_channels`. | |
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. | |
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet. | |
temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet. | |
merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing. | |
merge_strategy (`str`, *optional*, defaults to `learned_with_images`): | |
The merge strategy to use for the temporal mixing. | |
switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`): | |
If `True`, switch the spatial and temporal mixing. | |
""" | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
temb_channels: int = 512, | |
eps: float = 1e-6, | |
temporal_eps: Optional[float] = None, | |
merge_factor: float = 0.5, | |
merge_strategy="learned_with_images", | |
switch_spatial_to_temporal_mix: bool = False, | |
): | |
super().__init__() | |
self.spatial_res_block = ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=eps, | |
) | |
self.temporal_res_block = TemporalResnetBlock( | |
in_channels=out_channels if out_channels is not None else in_channels, | |
out_channels=out_channels if out_channels is not None else in_channels, | |
temb_channels=temb_channels, | |
eps=temporal_eps if temporal_eps is not None else eps, | |
) | |
self.time_mixer = AlphaBlender( | |
alpha=merge_factor, | |
merge_strategy=merge_strategy, | |
switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix, | |
) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
): | |
num_frames = image_only_indicator.shape[-1] | |
hidden_states = self.spatial_res_block(hidden_states, temb) | |
batch_frames, channels, height, width = hidden_states.shape | |
batch_size = batch_frames // num_frames | |
hidden_states_mix = ( | |
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) | |
) | |
hidden_states = ( | |
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) | |
) | |
if temb is not None: | |
temb = temb.reshape(batch_size, num_frames, -1) | |
hidden_states = self.temporal_res_block(hidden_states, temb) | |
hidden_states = self.time_mixer( | |
x_spatial=hidden_states_mix, | |
x_temporal=hidden_states, | |
image_only_indicator=image_only_indicator, | |
) | |
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) | |
return hidden_states | |
class AlphaBlender(nn.Module): | |
r""" | |
A module to blend spatial and temporal features. | |
Parameters: | |
alpha (`float`): The initial value of the blending factor. | |
merge_strategy (`str`, *optional*, defaults to `learned_with_images`): | |
The merge strategy to use for the temporal mixing. | |
switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`): | |
If `True`, switch the spatial and temporal mixing. | |
""" | |
strategies = ["learned", "fixed", "learned_with_images"] | |
def __init__( | |
self, | |
alpha: float, | |
merge_strategy: str = "learned_with_images", | |
switch_spatial_to_temporal_mix: bool = False, | |
): | |
super().__init__() | |
self.merge_strategy = merge_strategy | |
self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE | |
if merge_strategy not in self.strategies: | |
raise ValueError(f"merge_strategy needs to be in {self.strategies}") | |
if self.merge_strategy == "fixed": | |
self.register_buffer("mix_factor", torch.Tensor([alpha])) | |
elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images": | |
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) | |
else: | |
raise ValueError(f"Unknown merge strategy {self.merge_strategy}") | |
def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor: | |
if self.merge_strategy == "fixed": | |
alpha = self.mix_factor | |
elif self.merge_strategy == "learned": | |
alpha = torch.sigmoid(self.mix_factor) | |
elif self.merge_strategy == "learned_with_images": | |
if image_only_indicator is None: | |
raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy") | |
alpha = torch.where( | |
image_only_indicator.bool(), | |
torch.ones(1, 1, device=image_only_indicator.device), | |
torch.sigmoid(self.mix_factor)[..., None], | |
) | |
# (batch, channel, frames, height, width) | |
if ndims == 5: | |
alpha = alpha[:, None, :, None, None] | |
# (batch*frames, height*width, channels) | |
elif ndims == 3: | |
alpha = alpha.reshape(-1)[:, None, None] | |
else: | |
raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5") | |
else: | |
raise NotImplementedError | |
return alpha | |
def forward( | |
self, | |
x_spatial: torch.Tensor, | |
x_temporal: torch.Tensor, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
) -> torch.Tensor: | |
alpha = self.get_alpha(image_only_indicator, x_spatial.ndim) | |
alpha = alpha.to(x_spatial.dtype) | |
if self.switch_spatial_to_temporal_mix: | |
alpha = 1.0 - alpha | |
x = alpha * x_spatial + (1.0 - alpha) * x_temporal | |
return x | |
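# --- Hedged usage sketch (added for illustration; not part of the diffusers source). ---
# Intention: blend a spatial and a temporal feature map with the "learned" strategy,
# where alpha = sigmoid(mix_factor) and the output is
# alpha * x_spatial + (1 - alpha) * x_temporal. `_demo_alpha_blender` is a hypothetical
# helper name introduced just for this sketch.
def _demo_alpha_blender():
    blender = AlphaBlender(alpha=0.0, merge_strategy="learned")
    x_spatial = torch.randn(2, 64, 8, 16, 16)          # (batch, channels, frames, height, width)
    x_temporal = torch.randn(2, 64, 8, 16, 16)
    with torch.no_grad():
        mixed = blender(x_spatial, x_temporal)
    assert mixed.shape == x_spatial.shape
    # with mix_factor = 0.0, sigmoid gives alpha = 0.5: an even blend at initialization
    assert torch.allclose(mixed, 0.5 * (x_spatial + x_temporal))
    return mixed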
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
from pickle import UnpicklingError | |
from typing import Any, Dict, Union | |
import jax | |
import jax.numpy as jnp | |
import msgpack.exceptions | |
from flax.core.frozen_dict import FrozenDict, unfreeze | |
from flax.serialization import from_bytes, to_bytes | |
from flax.traverse_util import flatten_dict, unflatten_dict | |
from huggingface_hub import create_repo, hf_hub_download | |
from huggingface_hub.utils import ( | |
EntryNotFoundError, | |
RepositoryNotFoundError, | |
RevisionNotFoundError, | |
validate_hf_hub_args, | |
) | |
from requests import HTTPError | |
from .. import __version__, is_torch_available | |
from ..utils import ( | |
CONFIG_NAME, | |
FLAX_WEIGHTS_NAME, | |
HUGGINGFACE_CO_RESOLVE_ENDPOINT, | |
WEIGHTS_NAME, | |
PushToHubMixin, | |
logging, | |
) | |
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax | |
logger = logging.get_logger(__name__) | |
class FlaxModelMixin(PushToHubMixin): | |
r""" | |
Base class for all Flax models. | |
[`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and | |
saving models. | |
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`]. | |
""" | |
config_name = CONFIG_NAME | |
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] | |
_flax_internal_args = ["name", "parent", "dtype"] | |
@classmethod | |
def _from_config(cls, config, **kwargs): | |
""" | |
All context managers that the model should be initialized under go here. | |
""" | |
return cls(config, **kwargs) | |
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: | |
""" | |
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. | |
""" | |
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 | |
def conditional_cast(param): | |
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): | |
param = param.astype(dtype) | |
return param | |
if mask is None: | |
return jax.tree_map(conditional_cast, params) | |
flat_params = flatten_dict(params) | |
flat_mask, _ = jax.tree_flatten(mask) | |
for masked, key in zip(flat_mask, flat_params.keys()): | |
if masked: | |
param = flat_params[key] | |
flat_params[key] = conditional_cast(param) | |
return unflatten_dict(flat_params) | |
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): | |
r""" | |
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast | |
the `params` in place. | |
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full | |
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. | |
Arguments: | |
params (`Union[Dict, FrozenDict]`): | |
A `PyTree` of model parameters. | |
mask (`Union[Dict, FrozenDict]`): | |
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` | |
for params you want to cast, and `False` for those you want to skip. | |
Examples: | |
```python | |
>>> from diffusers import FlaxUNet2DConditionModel | |
>>> # load model | |
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") | |
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision | |
>>> params = model.to_bf16(params) | |
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale) | |
>>> # then pass the mask as follows | |
>>> from flax import traverse_util | |
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") | |
>>> flat_params = traverse_util.flatten_dict(params) | |
>>> mask = { | |
... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) | |
... for path in flat_params | |
... } | |
>>> mask = traverse_util.unflatten_dict(mask) | |
>>> params = model.to_bf16(params, mask) | |
```""" | |
return self._cast_floating_to(params, jnp.bfloat16, mask) | |
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): | |
r""" | |
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the | |
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. | |
Arguments: | |
params (`Union[Dict, FrozenDict]`): | |
A `PyTree` of model parameters. | |
mask (`Union[Dict, FrozenDict]`): | |
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` | |
for params you want to cast, and `False` for those you want to skip. | |
Examples: | |
```python | |
>>> from diffusers import FlaxUNet2DConditionModel | |
>>> # Download model and configuration from huggingface.co | |
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") | |
>>> # By default, the model params will be in fp32, to illustrate the use of this method, | |
>>> # we'll first cast to fp16 and back to fp32 | |
>>> params = model.to_fp16(params) | |
>>> # now cast back to fp32 | |
>>> params = model.to_fp32(params) | |
```""" | |
return self._cast_floating_to(params, jnp.float32, mask) | |
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): | |
r""" | |
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the | |
`params` in place. | |
This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full | |
half-precision training or to save weights in float16 for inference in order to save memory and improve speed. | |
Arguments: | |
params (`Union[Dict, FrozenDict]`): | |
A `PyTree` of model parameters. | |
mask (`Union[Dict, FrozenDict]`): | |
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` | |
for params you want to cast, and `False` for those you want to skip. | |
Examples: | |
```python | |
>>> from diffusers import FlaxUNet2DConditionModel | |
>>> # load model | |
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") | |
>>> # By default, the model params will be in fp32, to cast these to float16 | |
>>> params = model.to_fp16(params) | |
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale) | |
>>> # then pass the mask as follows | |
>>> from flax import traverse_util | |
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") | |
>>> flat_params = traverse_util.flatten_dict(params) | |
>>> mask = { | |
... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) | |
... for path in flat_params | |
... } | |
>>> mask = traverse_util.unflatten_dict(mask) | |
>>> params = model.to_fp16(params, mask) | |
```""" | |
return self._cast_floating_to(params, jnp.float16, mask) | |
def init_weights(self, rng: jax.Array) -> Dict: | |
raise NotImplementedError(f"init_weights method has to be implemented for {self}") | |
@classmethod | |
@validate_hf_hub_args | |
def from_pretrained( | |
cls, | |
pretrained_model_name_or_path: Union[str, os.PathLike], | |
dtype: jnp.dtype = jnp.float32, | |
*model_args, | |
**kwargs, | |
): | |
r""" | |
Instantiate a pretrained Flax model from a pretrained model configuration. | |
Parameters: | |
pretrained_model_name_or_path (`str` or `os.PathLike`): | |
Can be either: | |
- A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model | |
hosted on the Hub. | |
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved | |
using [`~FlaxModelMixin.save_pretrained`]. | |
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): | |
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and | |
`jax.numpy.bfloat16` (on TPUs). | |
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If | |
specified, all the computation will be performed with the given `dtype`. | |
<Tip> | |
This only specifies the dtype of the *computation* and does not influence the dtype of model | |
parameters. | |
If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and | |
[`~FlaxModelMixin.to_bf16`]. | |
</Tip> | |
model_args (sequence of positional arguments, *optional*): | |
All remaining positional arguments are passed to the underlying model's `__init__` method. | |
cache_dir (`Union[str, os.PathLike]`, *optional*): | |
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | |
is not used. | |
force_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
cached versions if they exist. | |
resume_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any | |
incompletely downloaded files are deleted. | |
proxies (`Dict[str, str]`, *optional*): | |
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | |
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
local_files_only (`bool`, *optional*, defaults to `False`): | |
Whether to only load local model weights and configuration files or not. If set to `True`, the model | |
won't be downloaded from the Hub. | |
revision (`str`, *optional*, defaults to `"main"`): | |
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | |
allowed by Git. | |
from_pt (`bool`, *optional*, defaults to `False`): | |
Load the model weights from a PyTorch checkpoint save file. | |
kwargs (remaining dictionary of keyword arguments, *optional*): | |
Can be used to update the configuration object (after it is loaded) and initiate the model (for | |
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or | |
automatically loaded: | |
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying | |
model's `__init__` method (we assume all relevant updates to the configuration have already been | |
done). | |
- If a configuration is not provided, `kwargs` are first passed to the configuration class | |
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds | |
to a configuration attribute is used to override said attribute with the supplied `kwargs` value. | |
Remaining keys that do not correspond to any configuration attribute are passed to the underlying | |
model's `__init__` function. | |
Examples: | |
```python | |
>>> from diffusers import FlaxUNet2DConditionModel | |
>>> # Download model and configuration from huggingface.co and cache. | |
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") | |
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). | |
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") | |
``` | |
If you get the error message below, you need to finetune the weights for your downstream task: | |
```bash | |
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: | |
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated | |
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. | |
``` | |
""" | |
config = kwargs.pop("config", None) | |
cache_dir = kwargs.pop("cache_dir", None) | |
force_download = kwargs.pop("force_download", False) | |
from_pt = kwargs.pop("from_pt", False) | |
resume_download = kwargs.pop("resume_download", False) | |
proxies = kwargs.pop("proxies", None) | |
local_files_only = kwargs.pop("local_files_only", False) | |
token = kwargs.pop("token", None) | |
revision = kwargs.pop("revision", None) | |
subfolder = kwargs.pop("subfolder", None) | |
user_agent = { | |
"diffusers": __version__, | |
"file_type": "model", | |
"framework": "flax", | |
} | |
# Load config if we don't provide one | |
if config is None: | |
config, unused_kwargs = cls.load_config( | |
pretrained_model_name_or_path, | |
cache_dir=cache_dir, | |
return_unused_kwargs=True, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
token=token, | |
revision=revision, | |
subfolder=subfolder, | |
**kwargs, | |
) | |
model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs) | |
# Load model | |
pretrained_path_with_subfolder = ( | |
pretrained_model_name_or_path | |
if subfolder is None | |
else os.path.join(pretrained_model_name_or_path, subfolder) | |
) | |
if os.path.isdir(pretrained_path_with_subfolder): | |
if from_pt: | |
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): | |
raise EnvironmentError( | |
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} " | |
) | |
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) | |
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)): | |
# Load from a Flax checkpoint | |
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME) | |
# Check if pytorch weights exist instead | |
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): | |
raise EnvironmentError( | |
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" | |
" using `from_pt=True`." | |
) | |
else: | |
raise EnvironmentError( | |
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " | |
f"{pretrained_path_with_subfolder}." | |
) | |
else: | |
try: | |
model_file = hf_hub_download( | |
pretrained_model_name_or_path, | |
filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, | |
cache_dir=cache_dir, | |
force_download=force_download, | |
proxies=proxies, | |
resume_download=resume_download, | |
local_files_only=local_files_only, | |
token=token, | |
user_agent=user_agent, | |
subfolder=subfolder, | |
revision=revision, | |
) | |
except RepositoryNotFoundError: | |
raise EnvironmentError( | |
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " | |
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " | |
"token having permission to this repo with `token` or log in with `huggingface-cli " | |
"login`." | |
) | |
except RevisionNotFoundError: | |
raise EnvironmentError( | |
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " | |
"this model name. Check the model page at " | |
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." | |
) | |
except EntryNotFoundError: | |
raise EnvironmentError( | |
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." | |
) | |
except HTTPError as err: | |
raise EnvironmentError( | |
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" | |
f"{err}" | |
) | |
except ValueError: | |
raise EnvironmentError( | |
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" | |
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" | |
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" | |
" internet connection or see how to run the library in offline mode at" | |
" 'https://huggingface.co/docs/transformers/installation#offline-mode'." | |
) | |
except EnvironmentError: | |
raise EnvironmentError( | |
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " | |
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. " | |
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " | |
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." | |
) | |
if from_pt: | |
if is_torch_available(): | |
from .modeling_utils import load_state_dict | |
else: | |
raise EnvironmentError( | |
"Can't load the model in PyTorch format because PyTorch is not installed. " | |
"Please, install PyTorch or use native Flax weights." | |
) | |
# Step 1: Get the pytorch file | |
pytorch_model_file = load_state_dict(model_file) | |
# Step 2: Convert the weights | |
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) | |
else: | |
try: | |
with open(model_file, "rb") as state_f: | |
state = from_bytes(cls, state_f.read()) | |
except (UnpicklingError, msgpack.exceptions.ExtraData) as e: | |
try: | |
with open(model_file) as f: | |
if f.read().startswith("version"): | |
raise OSError( | |
"You seem to have cloned a repository without having git-lfs installed. Please" | |
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the" | |
" folder you cloned." | |
) | |
else: | |
raise ValueError from e | |
except (UnicodeDecodeError, ValueError): | |
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") | |
# make sure all arrays are stored as jnp.ndarray | |
# NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4: | |
# https://github.com/google/flax/issues/1261 | |
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state) | |
# flatten dicts | |
state = flatten_dict(state) | |
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) | |
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) | |
shape_state = flatten_dict(unfreeze(params_shape_tree)) | |
missing_keys = required_params - set(state.keys()) | |
unexpected_keys = set(state.keys()) - required_params | |
if missing_keys: | |
logger.warning( | |
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " | |
"Make sure to call model.init_weights to initialize the missing weights." | |
) | |
cls._missing_keys = missing_keys | |
for key in state.keys(): | |
if key in shape_state and state[key].shape != shape_state[key].shape: | |
raise ValueError( | |
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " | |
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " | |
) | |
# remove unexpected keys to not be saved again | |
for unexpected_key in unexpected_keys: | |
del state[unexpected_key] | |
if len(unexpected_keys) > 0: | |
logger.warning( | |
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" | |
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" | |
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" | |
" with another architecture." | |
) | |
else: | |
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") | |
if len(missing_keys) > 0: | |
logger.warning( | |
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" | |
" TRAIN this model on a down-stream task to be able to use it for predictions and inference." | |
) | |
else: | |
logger.info( | |
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" | |
f" was trained on, you can already use {model.__class__.__name__} for predictions without further" | |
" training." | |
) | |
return model, unflatten_dict(state) | |
def save_pretrained( | |
self, | |
save_directory: Union[str, os.PathLike], | |
params: Union[Dict, FrozenDict], | |
is_main_process: bool = True, | |
push_to_hub: bool = False, | |
**kwargs, | |
): | |
""" | |
Save a model and its configuration file to a directory so that it can be reloaded using the | |
[`~FlaxModelMixin.from_pretrained`] class method. | |
Arguments: | |
save_directory (`str` or `os.PathLike`): | |
Directory to save a model and its configuration file to. Will be created if it doesn't exist. | |
params (`Union[Dict, FrozenDict]`): | |
A `PyTree` of model parameters. | |
is_main_process (`bool`, *optional*, defaults to `True`): | |
Whether the process calling this is the main process or not. Useful during distributed training when you | |
need to call this function on all processes. In this case, set `is_main_process=True` only on the main | |
process to avoid race conditions. | |
push_to_hub (`bool`, *optional*, defaults to `False`): | |
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the | |
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your | |
namespace). | |
kwargs (`Dict[str, Any]`, *optional*): | |
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. | |
""" | |
if os.path.isfile(save_directory): | |
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") | |
return | |
os.makedirs(save_directory, exist_ok=True) | |
if push_to_hub: | |
commit_message = kwargs.pop("commit_message", None) | |
private = kwargs.pop("private", False) | |
create_pr = kwargs.pop("create_pr", False) | |
token = kwargs.pop("token", None) | |
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) | |
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id | |
model_to_save = self | |
# Attach architecture to the config | |
# Save the config | |
if is_main_process: | |
model_to_save.save_config(save_directory) | |
# save model | |
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) | |
with open(output_model_file, "wb") as f: | |
model_bytes = to_bytes(params) | |
f.write(model_bytes) | |
logger.info(f"Model weights saved in {output_model_file}") | |
if push_to_hub: | |
self._upload_folder( | |
save_directory, | |
repo_id, | |
token=token, | |
commit_message=commit_message, | |
create_pr=create_pr, | |
) | |
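# Editor's note: a hedged sketch (not part of diffusers) of the save/load round trip
# implemented above, using `FlaxUNet2DConditionModel` as in the docstring examples.
# The `subfolder="unet"` argument and the local path are assumptions about the
# checkpoint layout; running this downloads the pretrained weights.
def _example_flax_save_and_reload(save_directory="./flax-unet-example"):
    from diffusers import FlaxUNet2DConditionModel

    # Fetch (or load from cache) the Flax weights and configuration from the Hub.
    model, params = FlaxUNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    # Write the config plus msgpack-serialized weights into `save_directory` ...
    model.save_pretrained(save_directory, params=params)
    # ... and reload them from disk with the same class method.
    model, params = FlaxUNet2DConditionModel.from_pretrained(save_directory)
    return model, params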
from ..utils import deprecate | |
from .unets.unet_1d_blocks import ( | |
AttnDownBlock1D, | |
AttnUpBlock1D, | |
DownBlock1D, | |
DownBlock1DNoSkip, | |
DownResnetBlock1D, | |
Downsample1d, | |
MidResTemporalBlock1D, | |
OutConv1DBlock, | |
OutValueFunctionBlock, | |
ResConvBlock, | |
SelfAttention1d, | |
UNetMidBlock1D, | |
UpBlock1D, | |
UpBlock1DNoSkip, | |
UpResnetBlock1D, | |
Upsample1d, | |
ValueFunctionMidBlock1D, | |
) | |
class DownResnetBlock1D(DownResnetBlock1D): | |
deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead." | |
deprecate("DownResnetBlock1D", "0.29", deprecation_message) | |
class UpResnetBlock1D(UpResnetBlock1D): | |
deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead." | |
deprecate("UpResnetBlock1D", "0.29", deprecation_message) | |
class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D): | |
deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead." | |
deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message) | |
class OutConv1DBlock(OutConv1DBlock): | |
deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead." | |
deprecate("OutConv1DBlock", "0.29", deprecation_message) | |
class OutValueFunctionBlock(OutValueFunctionBlock): | |
deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead." | |
deprecate("OutValueFunctionBlock", "0.29", deprecation_message) | |
class Downsample1d(Downsample1d): | |
deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead." | |
deprecate("Downsample1d", "0.29", deprecation_message) | |
class Upsample1d(Upsample1d): | |
deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead." | |
deprecate("Upsample1d", "0.29", deprecation_message) | |
class SelfAttention1d(SelfAttention1d): | |
deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead." | |
deprecate("SelfAttention1d", "0.29", deprecation_message) | |
class ResConvBlock(ResConvBlock): | |
deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead." | |
deprecate("ResConvBlock", "0.29", deprecation_message) | |
class UNetMidBlock1D(UNetMidBlock1D): | |
deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead." | |
deprecate("UNetMidBlock1D", "0.29", deprecation_message) | |
class AttnDownBlock1D(AttnDownBlock1D): | |
deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead." | |
deprecate("AttnDownBlock1D", "0.29", deprecation_message) | |
class DownBlock1D(DownBlock1D): | |
deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead." | |
deprecate("DownBlock1D", "0.29", deprecation_message) | |
class DownBlock1DNoSkip(DownBlock1DNoSkip): | |
deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead." | |
deprecate("DownBlock1DNoSkip", "0.29", deprecation_message) | |
class AttnUpBlock1D(AttnUpBlock1D): | |
deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead." | |
deprecate("AttnUpBlock1D", "0.29", deprecation_message) | |
class UpBlock1D(UpBlock1D): | |
deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead." | |
deprecate("UpBlock1D", "0.29", deprecation_message) | |
class UpBlock1DNoSkip(UpBlock1DNoSkip): | |
deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead." | |
deprecate("UpBlock1DNoSkip", "0.29", deprecation_message) | |
class MidResTemporalBlock1D(MidResTemporalBlock1D): | |
deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead." | |
deprecate("MidResTemporalBlock1D", "0.29", deprecation_message) | |
def get_down_block( | |
down_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
add_downsample: bool, | |
): | |
deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead." | |
deprecate("get_down_block", "0.29", deprecation_message) | |
from .unets.unet_1d_blocks import get_down_block | |
return get_down_block( | |
down_block_type=down_block_type, | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
) | |
def get_up_block( | |
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool | |
): | |
deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead." | |
deprecate("get_up_block", "0.29", deprecation_message) | |
from .unets.unet_1d_blocks import get_up_block | |
return get_up_block( | |
up_block_type=up_block_type, | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_upsample=add_upsample, | |
) | |
def get_mid_block( | |
mid_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
mid_channels: int, | |
out_channels: int, | |
embed_dim: int, | |
add_downsample: bool, | |
): | |
deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead." | |
deprecate("get_mid_block", "0.29", deprecation_message) | |
from .unets.unet_1d_blocks import get_mid_block | |
return get_mid_block( | |
mid_block_type=mid_block_type, | |
num_layers=num_layers, | |
in_channels=in_channels, | |
mid_channels=mid_channels, | |
out_channels=out_channels, | |
embed_dim=embed_dim, | |
add_downsample=add_downsample, | |
) | |
def get_out_block( | |
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int | |
): | |
deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead." | |
deprecate("get_out_block", "0.29", deprecation_message) | |
from .unets.unet_1d_blocks import get_out_block | |
return get_out_block( | |
out_block_type=out_block_type, | |
num_groups_out=num_groups_out, | |
embed_dim=embed_dim, | |
out_channels=out_channels, | |
act_fn=act_fn, | |
fc_dim=fc_dim, | |
) | |
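# Editor's note (not part of diffusers): per the deprecation messages above, new code
# should import the 1D UNet blocks and helpers from the `unets` subpackage directly,
# e.g. `from diffusers.models.unets.unet_1d_blocks import get_down_block, get_up_block`.
# The wrappers in this module only re-export those symbols and warn that they are
# scheduled for removal in 0.29.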
import math | |
from typing import List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
from torch import nn | |
from ..utils import USE_PEFT_BACKEND, deprecate | |
from .activations import get_activation | |
from .attention_processor import Attention | |
from .lora import LoRACompatibleLinear | |
def get_timestep_embedding( | |
timesteps: torch.Tensor, | |
embedding_dim: int, | |
flip_sin_to_cos: bool = False, | |
downscale_freq_shift: float = 1, | |
scale: float = 1, | |
max_period: int = 10000, | |
): | |
""" | |
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. | |
:param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. | |
:param embedding_dim: the dimension of the output. | |
:param max_period: controls the minimum frequency of the embeddings. | |
:return: an [N x dim] Tensor of positional embeddings. | |
""" | |
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" | |
half_dim = embedding_dim // 2 | |
exponent = -math.log(max_period) * torch.arange( | |
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device | |
) | |
exponent = exponent / (half_dim - downscale_freq_shift) | |
emb = torch.exp(exponent) | |
emb = timesteps[:, None].float() * emb[None, :] | |
# scale embeddings | |
emb = scale * emb | |
# concat sine and cosine embeddings | |
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) | |
# flip sine and cosine embeddings | |
if flip_sin_to_cos: | |
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) | |
# zero pad | |
if embedding_dim % 2 == 1: | |
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) | |
return emb | |
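# Editor's note: a minimal shape-check sketch for `get_timestep_embedding` above
# (not part of diffusers); the timestep values and embedding size are illustrative.
def _example_timestep_embedding():
    timesteps = torch.tensor([0, 250, 999])  # one diffusion step per batch element
    emb = get_timestep_embedding(timesteps, embedding_dim=320)
    assert emb.shape == (3, 320)             # [N x dim] sinusoidal embeddings
    return emb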
def get_2d_sincos_pos_embed( | |
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 | |
): | |
""" | |
grid_size: int of the grid height and width | |
return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) | |
""" | |
if isinstance(grid_size, int): | |
grid_size = (grid_size, grid_size) | |
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale | |
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale | |
grid = np.meshgrid(grid_w, grid_h) # here w goes first | |
grid = np.stack(grid, axis=0) | |
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) | |
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) | |
if cls_token and extra_tokens > 0: | |
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) | |
return pos_embed | |
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): | |
if embed_dim % 2 != 0: | |
raise ValueError("embed_dim must be divisible by 2") | |
# use half of dimensions to encode grid_h | |
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) | |
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) | |
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) | |
return emb | |
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): | |
""" | |
embed_dim: output dimension for each position | |
pos: a list of positions to be encoded: size (M,) | |
out: (M, D) | |
""" | |
if embed_dim % 2 != 0: | |
raise ValueError("embed_dim must be divisible by 2") | |
omega = np.arange(embed_dim // 2, dtype=np.float64) | |
omega /= embed_dim / 2.0 | |
omega = 1.0 / 10000**omega # (D/2,) | |
pos = pos.reshape(-1) # (M,) | |
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product | |
emb_sin = np.sin(out) # (M, D/2) | |
emb_cos = np.cos(out) # (M, D/2) | |
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) | |
return emb | |
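# Editor's note: a small sketch (not part of diffusers) that builds the 2D sin-cos
# positional embedding for a 16x16 patch grid, as consumed by `PatchEmbed` below.
def _example_2d_sincos_pos_embed():
    pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
    assert pos_embed.shape == (16 * 16, 768)  # one embedding per patch position
    return pos_embed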
class PatchEmbed(nn.Module): | |
"""2D Image to Patch Embedding""" | |
def __init__( | |
self, | |
height=224, | |
width=224, | |
patch_size=16, | |
in_channels=3, | |
embed_dim=768, | |
layer_norm=False, | |
flatten=True, | |
bias=True, | |
interpolation_scale=1, | |
): | |
super().__init__() | |
num_patches = (height // patch_size) * (width // patch_size) | |
self.flatten = flatten | |
self.layer_norm = layer_norm | |
self.proj = nn.Conv2d( | |
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias | |
) | |
if layer_norm: | |
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) | |
else: | |
self.norm = None | |
self.patch_size = patch_size | |
# See: | |
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 | |
self.height, self.width = height // patch_size, width // patch_size | |
self.base_size = height // patch_size | |
self.interpolation_scale = interpolation_scale | |
pos_embed = get_2d_sincos_pos_embed( | |
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale | |
) | |
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) | |
def forward(self, latent): | |
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size | |
latent = self.proj(latent) | |
if self.flatten: | |
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC | |
if self.layer_norm: | |
latent = self.norm(latent) | |
# Interpolate positional embeddings if needed. | |
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) | |
if self.height != height or self.width != width: | |
pos_embed = get_2d_sincos_pos_embed( | |
embed_dim=self.pos_embed.shape[-1], | |
grid_size=(height, width), | |
base_size=self.base_size, | |
interpolation_scale=self.interpolation_scale, | |
) | |
pos_embed = torch.from_numpy(pos_embed) | |
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) | |
else: | |
pos_embed = self.pos_embed | |
return (latent + pos_embed).to(latent.dtype) | |
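# Editor's note: a hedged usage sketch for `PatchEmbed` (not part of diffusers).
# The latent channel count, spatial size, and patch size are assumptions.
def _example_patch_embed():
    patchify = PatchEmbed(height=64, width=64, patch_size=16, in_channels=4, embed_dim=768)
    latent = torch.randn(2, 4, 64, 64)                # (batch, channels, height, width)
    tokens = patchify(latent)
    assert tokens.shape == (2, (64 // 16) ** 2, 768)  # (batch, num_patches, embed_dim)
    return tokens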
class TimestepEmbedding(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
time_embed_dim: int, | |
act_fn: str = "silu", | |
out_dim: int = None, | |
post_act_fn: Optional[str] = None, | |
cond_proj_dim=None, | |
sample_proj_bias=True, | |
): | |
super().__init__() | |
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear | |
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias) | |
if cond_proj_dim is not None: | |
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) | |
else: | |
self.cond_proj = None | |
self.act = get_activation(act_fn) | |
if out_dim is not None: | |
time_embed_dim_out = out_dim | |
else: | |
time_embed_dim_out = time_embed_dim | |
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias) | |
if post_act_fn is None: | |
self.post_act = None | |
else: | |
self.post_act = get_activation(post_act_fn) | |
def forward(self, sample, condition=None): | |
if condition is not None: | |
sample = sample + self.cond_proj(condition) | |
sample = self.linear_1(sample) | |
if self.act is not None: | |
sample = self.act(sample) | |
sample = self.linear_2(sample) | |
if self.post_act is not None: | |
sample = self.post_act(sample) | |
return sample | |
class Timesteps(nn.Module): | |
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): | |
super().__init__() | |
self.num_channels = num_channels | |
self.flip_sin_to_cos = flip_sin_to_cos | |
self.downscale_freq_shift = downscale_freq_shift | |
def forward(self, timesteps): | |
t_emb = get_timestep_embedding( | |
timesteps, | |
self.num_channels, | |
flip_sin_to_cos=self.flip_sin_to_cos, | |
downscale_freq_shift=self.downscale_freq_shift, | |
) | |
return t_emb | |
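# Editor's note: a sketch (not part of diffusers) of the usual two-stage timestep
# pathway -- sinusoidal projection via `Timesteps`, then a small MLP via
# `TimestepEmbedding`. The 320 -> 1280 sizes mirror common UNet configs but are
# assumptions here.
def _example_time_embedding_pathway():
    time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
    time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

    timesteps = torch.randint(0, 1000, (4,))  # one diffusion step per sample
    t_emb = time_proj(timesteps)              # (4, 320) sinusoidal features
    t_emb = time_embedding(t_emb)             # (4, 1280) learned embedding
    assert t_emb.shape == (4, 1280)
    return t_emb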
class GaussianFourierProjection(nn.Module): | |
"""Gaussian Fourier embeddings for noise levels.""" | |
def __init__( | |
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False | |
): | |
super().__init__() | |
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) | |
self.log = log | |
self.flip_sin_to_cos = flip_sin_to_cos | |
if set_W_to_weight: | |
# to delete later | |
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) | |
self.weight = self.W | |
def forward(self, x): | |
if self.log: | |
x = torch.log(x) | |
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi | |
if self.flip_sin_to_cos: | |
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) | |
else: | |
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) | |
return out | |
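# Editor's note: a minimal sketch for `GaussianFourierProjection` (not part of
# diffusers). With the default `log=True`, the inputs must be positive noise levels.
def _example_gaussian_fourier_projection():
    fourier = GaussianFourierProjection(embedding_size=256, scale=16.0)
    sigmas = torch.tensor([0.1, 1.0, 10.0])  # positive continuous noise levels
    emb = fourier(sigmas)
    assert emb.shape == (3, 512)             # sin and cos halves concatenated
    return emb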
class SinusoidalPositionalEmbedding(nn.Module): | |
"""Apply positional information to a sequence of embeddings. | |
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to | |
them | |
Args: | |
embed_dim: (int): Dimension of the positional embedding. | |
max_seq_length: Maximum sequence length to apply positional embeddings | |
""" | |
def __init__(self, embed_dim: int, max_seq_length: int = 32): | |
super().__init__() | |
position = torch.arange(max_seq_length).unsqueeze(1) | |
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) | |
pe = torch.zeros(1, max_seq_length, embed_dim) | |
pe[0, :, 0::2] = torch.sin(position * div_term) | |
pe[0, :, 1::2] = torch.cos(position * div_term) | |
self.register_buffer("pe", pe) | |
def forward(self, x): | |
_, seq_length, _ = x.shape | |
x = x + self.pe[:, :seq_length] | |
return x | |
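# Editor's note: a usage sketch for `SinusoidalPositionalEmbedding` (not part of
# diffusers); the sequence length must not exceed `max_seq_length`.
def _example_sinusoidal_positional_embedding():
    pos_embed = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)
    tokens = torch.randn(2, 16, 64)     # (batch, seq_length, embed_dim)
    tokens = pos_embed(tokens)
    assert tokens.shape == (2, 16, 64)  # same shape, positions added elementwise
    return tokens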
class ImagePositionalEmbeddings(nn.Module): | |
""" | |
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the | |
height and width of the latent space. | |
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 | |
For VQ-diffusion: | |
Output vector embeddings are used as input for the transformer. | |
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. | |
Args: | |
num_embed (`int`): | |
Number of embeddings for the latent pixels embeddings. | |
height (`int`): | |
Height of the latent image i.e. the number of height embeddings. | |
width (`int`): | |
Width of the latent image i.e. the number of width embeddings. | |
embed_dim (`int`): | |
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. | |
""" | |
def __init__( | |
self, | |
num_embed: int, | |
height: int, | |
width: int, | |
embed_dim: int, | |
): | |
super().__init__() | |
self.height = height | |
self.width = width | |
self.num_embed = num_embed | |
self.embed_dim = embed_dim | |
self.emb = nn.Embedding(self.num_embed, embed_dim) | |
self.height_emb = nn.Embedding(self.height, embed_dim) | |
self.width_emb = nn.Embedding(self.width, embed_dim) | |
def forward(self, index): | |
emb = self.emb(index) | |
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) | |
# 1 x H x D -> 1 x H x 1 x D | |
height_emb = height_emb.unsqueeze(2) | |
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) | |
# 1 x W x D -> 1 x 1 x W x D | |
width_emb = width_emb.unsqueeze(1) | |
pos_emb = height_emb + width_emb | |
# 1 x H x W x D -> 1 x L x D | |
pos_emb = pos_emb.view(1, self.height * self.width, -1) | |
emb = emb + pos_emb[:, : emb.shape[1], :] | |
return emb | |
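# Editor's note: a hedged sketch for `ImagePositionalEmbeddings` (not part of
# diffusers), embedding discrete VQ latent codes for a 32x32 grid. Sizes assumed.
def _example_image_positional_embeddings():
    embed = ImagePositionalEmbeddings(num_embed=1024, height=32, width=32, embed_dim=512)
    codes = torch.randint(0, 1024, (2, 32 * 32))  # one codebook index per latent position
    hidden_states = embed(codes)
    assert hidden_states.shape == (2, 32 * 32, 512)
    return hidden_states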
class LabelEmbedding(nn.Module): | |
""" | |
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. | |
Args: | |
num_classes (`int`): The number of classes. | |
hidden_size (`int`): The size of the vector embeddings. | |
dropout_prob (`float`): The probability of dropping a label. | |
""" | |
def __init__(self, num_classes, hidden_size, dropout_prob): | |
super().__init__() | |
use_cfg_embedding = dropout_prob > 0 | |
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) | |
self.num_classes = num_classes | |
self.dropout_prob = dropout_prob | |
def token_drop(self, labels, force_drop_ids=None): | |
""" | |
Drops labels to enable classifier-free guidance. | |
""" | |
if force_drop_ids is None: | |
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob | |
else: | |
drop_ids = torch.tensor(force_drop_ids == 1) | |
labels = torch.where(drop_ids, self.num_classes, labels) | |
return labels | |
def forward(self, labels: torch.LongTensor, force_drop_ids=None): | |
use_dropout = self.dropout_prob > 0 | |
if (self.training and use_dropout) or (force_drop_ids is not None): | |
labels = self.token_drop(labels, force_drop_ids) | |
embeddings = self.embedding_table(labels) | |
return embeddings | |
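# Editor's note: a sketch (not part of diffusers) of `LabelEmbedding` with label
# dropout for classifier-free guidance; index `num_classes` acts as the learned
# "null" class. Sizes and dropout probability are illustrative.
def _example_label_embedding():
    label_embed = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
    labels = torch.randint(0, 1000, (8,))
    label_embed.train()  # in training, ~10% of labels are dropped to the null class
    emb = label_embed(labels)
    assert emb.shape == (8, 768)
    # Force every label to the null class, e.g. for the unconditional guidance branch.
    uncond = label_embed(labels, force_drop_ids=torch.ones_like(labels))
    assert uncond.shape == (8, 768)
    return emb, uncond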
class TextImageProjection(nn.Module): | |
def __init__( | |
self, | |
text_embed_dim: int = 1024, | |
image_embed_dim: int = 768, | |
cross_attention_dim: int = 768, | |
num_image_text_embeds: int = 10, | |
): | |
super().__init__() | |
self.num_image_text_embeds = num_image_text_embeds | |
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) | |
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) | |
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): | |
batch_size = text_embeds.shape[0] | |
# image | |
image_text_embeds = self.image_embeds(image_embeds) | |
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) | |
# text | |
text_embeds = self.text_proj(text_embeds) | |
return torch.cat([image_text_embeds, text_embeds], dim=1) | |
class ImageProjection(nn.Module): | |
def __init__( | |
self, | |
image_embed_dim: int = 768, | |
cross_attention_dim: int = 768, | |
num_image_text_embeds: int = 32, | |
): | |
super().__init__() | |
self.num_image_text_embeds = num_image_text_embeds | |
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) | |
self.norm = nn.LayerNorm(cross_attention_dim) | |
def forward(self, image_embeds: torch.FloatTensor): | |
batch_size = image_embeds.shape[0] | |
# image | |
image_embeds = self.image_embeds(image_embeds) | |
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) | |
image_embeds = self.norm(image_embeds) | |
return image_embeds | |
class IPAdapterFullImageProjection(nn.Module): | |
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): | |
super().__init__() | |
from .attention import FeedForward | |
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") | |
self.norm = nn.LayerNorm(cross_attention_dim) | |
def forward(self, image_embeds: torch.FloatTensor): | |
return self.norm(self.ff(image_embeds)) | |
class CombinedTimestepLabelEmbeddings(nn.Module): | |
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): | |
super().__init__() | |
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) | |
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) | |
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) | |
def forward(self, timestep, class_labels, hidden_dtype=None): | |
timesteps_proj = self.time_proj(timestep) | |
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) | |
class_labels = self.class_embedder(class_labels) # (N, D) | |
conditioning = timesteps_emb + class_labels # (N, D) | |
return conditioning | |
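# Editor's note: a sketch (not part of diffusers) combining timestep and class-label
# conditioning, as used by class-conditional DiT-style transformers. Sizes assumed.
def _example_combined_timestep_label_embeddings():
    cond_embed = CombinedTimestepLabelEmbeddings(num_classes=1000, embedding_dim=1152)
    timesteps = torch.randint(0, 1000, (4,))
    class_labels = torch.randint(0, 1000, (4,))
    conditioning = cond_embed(timesteps, class_labels, hidden_dtype=torch.float32)
    assert conditioning.shape == (4, 1152)  # summed timestep + label embedding
    return conditioning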
class TextTimeEmbedding(nn.Module): | |
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): | |
super().__init__() | |
self.norm1 = nn.LayerNorm(encoder_dim) | |
self.pool = AttentionPooling(num_heads, encoder_dim) | |
self.proj = nn.Linear(encoder_dim, time_embed_dim) | |
self.norm2 = nn.LayerNorm(time_embed_dim) | |
def forward(self, hidden_states): | |
hidden_states = self.norm1(hidden_states) | |
hidden_states = self.pool(hidden_states) | |
hidden_states = self.proj(hidden_states) | |
hidden_states = self.norm2(hidden_states) | |
return hidden_states | |
class TextImageTimeEmbedding(nn.Module): | |
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): | |
super().__init__() | |
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) | |
self.text_norm = nn.LayerNorm(time_embed_dim) | |
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) | |
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): | |
# text | |
time_text_embeds = self.text_proj(text_embeds) | |
time_text_embeds = self.text_norm(time_text_embeds) | |
# image | |
time_image_embeds = self.image_proj(image_embeds) | |
return time_image_embeds + time_text_embeds | |
class ImageTimeEmbedding(nn.Module): | |
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): | |
super().__init__() | |
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) | |
self.image_norm = nn.LayerNorm(time_embed_dim) | |
def forward(self, image_embeds: torch.FloatTensor): | |
# image | |
time_image_embeds = self.image_proj(image_embeds) | |
time_image_embeds = self.image_norm(time_image_embeds) | |
return time_image_embeds | |
class ImageHintTimeEmbedding(nn.Module): | |
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): | |
super().__init__() | |
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) | |
self.image_norm = nn.LayerNorm(time_embed_dim) | |
self.input_hint_block = nn.Sequential( | |
nn.Conv2d(3, 16, 3, padding=1), | |
nn.SiLU(), | |
nn.Conv2d(16, 16, 3, padding=1), | |
nn.SiLU(), | |
nn.Conv2d(16, 32, 3, padding=1, stride=2), | |
nn.SiLU(), | |
nn.Conv2d(32, 32, 3, padding=1), | |
nn.SiLU(), | |
nn.Conv2d(32, 96, 3, padding=1, stride=2), | |
nn.SiLU(), | |
nn.Conv2d(96, 96, 3, padding=1), | |
nn.SiLU(), | |
nn.Conv2d(96, 256, 3, padding=1, stride=2), | |
nn.SiLU(), | |
nn.Conv2d(256, 4, 3, padding=1), | |
) | |
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor): | |
# image | |
time_image_embeds = self.image_proj(image_embeds) | |
time_image_embeds = self.image_norm(time_image_embeds) | |
hint = self.input_hint_block(hint) | |
return time_image_embeds, hint | |
class AttentionPooling(nn.Module): | |
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 | |
def __init__(self, num_heads, embed_dim, dtype=None): | |
super().__init__() | |
self.dtype = dtype | |
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) | |
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) | |
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) | |
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) | |
self.num_heads = num_heads | |
self.dim_per_head = embed_dim // self.num_heads | |
def forward(self, x): | |
bs, length, width = x.size() | |
def shape(x): | |
# (bs, length, width) --> (bs, length, n_heads, dim_per_head) | |
x = x.view(bs, -1, self.num_heads, self.dim_per_head) | |
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) | |
x = x.transpose(1, 2) | |
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) | |
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) | |
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) | |
x = x.transpose(1, 2) | |
return x | |
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) | |
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) | |
# (bs*n_heads, class_token_length, dim_per_head) | |
q = shape(self.q_proj(class_token)) | |
# (bs*n_heads, length+class_token_length, dim_per_head) | |
k = shape(self.k_proj(x)) | |
v = shape(self.v_proj(x)) | |
# (bs*n_heads, class_token_length, length+class_token_length): | |
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) | |
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards | |
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) | |
# (bs*n_heads, dim_per_head, class_token_length) | |
a = torch.einsum("bts,bcs->bct", weight, v) | |
# (bs, length+1, width) | |
a = a.reshape(bs, -1, 1).transpose(1, 2) | |
return a[:, 0, :] # cls_token | |
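# Editor's note: a usage sketch for `AttentionPooling` (not part of diffusers),
# pooling a sequence of encoder states into a single vector. Sizes are illustrative.
def _example_attention_pooling():
    pool = AttentionPooling(num_heads=8, embed_dim=768)
    hidden_states = torch.randn(2, 77, 768)  # (batch, sequence length, width)
    pooled = pool(hidden_states)
    assert pooled.shape == (2, 768)          # one pooled vector per sample
    return pooled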
def get_fourier_embeds_from_boundingbox(embed_dim, box): | |
""" | |
Args: | |
embed_dim: int | |
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline | |
Returns: | |
[B x N x embed_dim*2*4] tensor of positional embeddings (sin/cos pairs for each of the 4 box coordinates) | |
""" | |
batch_size, num_boxes = box.shape[:2] | |
emb = 100 ** (torch.arange(embed_dim) / embed_dim) | |
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) | |
emb = emb * box.unsqueeze(-1) | |
emb = torch.stack((emb.sin(), emb.cos()), dim=-1) | |
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) | |
return emb | |
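# Editor's note: a sketch (not part of diffusers) of the bounding-box Fourier
# embedding above, with boxes in normalized xyxy format. Values are illustrative.
def _example_boundingbox_fourier_embeds():
    boxes = torch.tensor([[[0.1, 0.2, 0.5, 0.6],
                           [0.0, 0.0, 1.0, 1.0]]])  # (B=1, N=2, 4) normalized xyxy
    emb = get_fourier_embeds_from_boundingbox(embed_dim=8, box=boxes)
    assert emb.shape == (1, 2, 8 * 2 * 4)           # 8 freqs x sin/cos x 4 coordinates
    return emb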
class GLIGENTextBoundingboxProjection(nn.Module): | |
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): | |
super().__init__() | |
self.positive_len = positive_len | |
self.out_dim = out_dim | |
self.fourier_embedder_dim = fourier_freqs | |
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy | |
if isinstance(out_dim, tuple): | |
out_dim = out_dim[0] | |
if feature_type == "text-only": | |
self.linears = nn.Sequential( | |
nn.Linear(self.positive_len + self.position_dim, 512), | |
nn.SiLU(), | |
nn.Linear(512, 512), | |
nn.SiLU(), | |
nn.Linear(512, out_dim), | |
) | |
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) | |
elif feature_type == "text-image": | |
self.linears_text = nn.Sequential( | |
nn.Linear(self.positive_len + self.position_dim, 512), | |
nn.SiLU(), | |
nn.Linear(512, 512), | |
nn.SiLU(), | |
nn.Linear(512, out_dim), | |
) | |
self.linears_image = nn.Sequential( | |
nn.Linear(self.positive_len + self.position_dim, 512), | |
nn.SiLU(), | |
nn.Linear(512, 512), | |
nn.SiLU(), | |
nn.Linear(512, out_dim), | |
) | |
self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) | |
self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) | |
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) | |
def forward( | |
self, | |
boxes, | |
masks, | |
positive_embeddings=None, | |
phrases_masks=None, | |
image_masks=None, | |
phrases_embeddings=None, | |
image_embeddings=None, | |
): | |
masks = masks.unsqueeze(-1) | |
# embedding position (it may include padding as placeholder) | |
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C | |
# learnable null embedding | |
xyxy_null = self.null_position_feature.view(1, 1, -1) | |
# replace padding with learnable null embedding | |
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null | |
# PositionNet with text-only information | |
if positive_embeddings is not None: | |
# learnable null embedding | |
positive_null = self.null_positive_feature.view(1, 1, -1) | |
# replace padding with learnable null embedding | |
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null | |
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) | |
# PositionNet with text and image information | |
else: | |
phrases_masks = phrases_masks.unsqueeze(-1) | |
image_masks = image_masks.unsqueeze(-1) | |
# learnable null embedding | |
text_null = self.null_text_feature.view(1, 1, -1) | |
image_null = self.null_image_feature.view(1, 1, -1) | |
# replace padding with learnable null embedding | |
phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null | |
image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null | |
objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) | |
objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) | |
objs = torch.cat([objs_text, objs_image], dim=1) | |
return objs | |
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): | |
""" | |
For PixArt-Alpha. | |
Reference: | |
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 | |
""" | |
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): | |
super().__init__() | |
self.outdim = size_emb_dim | |
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) | |
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) | |
self.use_additional_conditions = use_additional_conditions | |
if use_additional_conditions: | |
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) | |
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) | |
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) | |
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): | |
timesteps_proj = self.time_proj(timestep) | |
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) | |
if self.use_additional_conditions: | |
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) | |
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) | |
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) | |
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) | |
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) | |
else: | |
conditioning = timesteps_emb | |
return conditioning | |
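# --- Example (illustrative sketch, not part of the diffusers source) ---
# Shape check for the combined embedding above: with the additional conditions enabled,
# `size_emb_dim` is assumed to be `embedding_dim // 3` (as in PixArt-Alpha) so that the
# resolution and aspect-ratio embeddings can be added onto the timestep embedding.
import torch

batch, embedding_dim = 2, 1152
emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim=embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=True
)
timestep = torch.randint(0, 1000, (batch,))
resolution = torch.tensor([[1024.0, 1024.0]] * batch)   # (height, width) per sample
aspect_ratio = torch.tensor([[1.0]] * batch)

cond = emb(timestep, resolution, aspect_ratio, batch_size=batch, hidden_dtype=torch.float32)
assert cond.shape == (batch, embedding_dim)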
class PixArtAlphaTextProjection(nn.Module): | |
""" | |
Projects caption embeddings. Also handles dropout for classifier-free guidance. | |
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py | |
""" | |
def __init__(self, in_features, hidden_size, num_tokens=120): | |
super().__init__() | |
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) | |
self.act_1 = nn.GELU(approximate="tanh") | |
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) | |
def forward(self, caption): | |
hidden_states = self.linear_1(caption) | |
hidden_states = self.act_1(hidden_states) | |
hidden_states = self.linear_2(hidden_states) | |
return hidden_states | |
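# --- Example (illustrative sketch, not part of the diffusers source) ---
# The projection above is a simple two-layer MLP that maps text-encoder features (e.g.
# 4096-dim T5 states as used by PixArt-Alpha, an assumption here) into the transformer width.
import torch

proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=1152)
caption = torch.randn(2, 120, 4096)   # (batch, num_tokens, text-encoder dim)
hidden = proj(caption)
assert hidden.shape == (2, 120, 1152)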
class IPAdapterPlusImageProjection(nn.Module): | |
"""Resampler of IP-Adapter Plus. | |
Args: | |
---- | |
embed_dims (int): The feature dimension. Defaults to 768. | |
output_dims (int): The number of output channels, which should match | |
`unet.config.cross_attention_dim`. Defaults to 1024. | |
hidden_dims (int): The number of hidden channels. Defaults to 1280. | |
depth (int): The number of attention/feed-forward blocks. Defaults to 4. | |
dim_head (int): The number of head channels. Defaults to 64. | |
heads (int): Parallel attention heads. Defaults to 16. | |
num_queries (int): The number of queries. Defaults to 8. | |
ffn_ratio (float): The expansion ratio of feedforward network hidden | |
layer channels. Defaults to 4. | |
""" | |
def __init__( | |
self, | |
embed_dims: int = 768, | |
output_dims: int = 1024, | |
hidden_dims: int = 1280, | |
depth: int = 4, | |
dim_head: int = 64, | |
heads: int = 16, | |
num_queries: int = 8, | |
ffn_ratio: float = 4, | |
) -> None: | |
super().__init__() | |
from .attention import FeedForward # Lazy import to avoid circular import | |
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) | |
self.proj_in = nn.Linear(embed_dims, hidden_dims) | |
self.proj_out = nn.Linear(hidden_dims, output_dims) | |
self.norm_out = nn.LayerNorm(output_dims) | |
self.layers = nn.ModuleList([]) | |
for _ in range(depth): | |
self.layers.append( | |
nn.ModuleList( | |
[ | |
nn.LayerNorm(hidden_dims), | |
nn.LayerNorm(hidden_dims), | |
Attention( | |
query_dim=hidden_dims, | |
dim_head=dim_head, | |
heads=heads, | |
out_bias=False, | |
), | |
nn.Sequential( | |
nn.LayerNorm(hidden_dims), | |
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), | |
), | |
] | |
) | |
) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
"""Forward pass. | |
Args: | |
---- | |
x (torch.Tensor): Image features of shape (batch_size, sequence_length, embed_dims). | |
Returns: | |
------- | |
torch.Tensor: Resampled tokens of shape (batch_size, num_queries, output_dims). | |
""" | |
latents = self.latents.repeat(x.size(0), 1, 1) | |
x = self.proj_in(x) | |
for ln0, ln1, attn, ff in self.layers: | |
residual = latents | |
encoder_hidden_states = ln0(x) | |
latents = ln1(latents) | |
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) | |
latents = attn(latents, encoder_hidden_states) + residual | |
latents = ff(latents) + latents | |
latents = self.proj_out(latents) | |
return self.norm_out(latents) | |
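# --- Example (illustrative sketch, not part of the diffusers source) ---
# The resampler above compresses a variable-length sequence of image features into a
# fixed set of `num_queries` tokens sized for the UNet's cross-attention dimension.
# The CLIP-like input shape (257 patch tokens of width 1280) is an assumption.
import torch

resampler = IPAdapterPlusImageProjection(
    embed_dims=1280, output_dims=2048, hidden_dims=1280, depth=4, num_queries=16
)
image_features = torch.randn(2, 257, 1280)   # (batch, sequence_length, embed_dims)
tokens = resampler(image_features)
assert tokens.shape == (2, 16, 2048)          # (batch, num_queries, output_dims)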
class MultiIPAdapterImageProjection(nn.Module): | |
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): | |
super().__init__() | |
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) | |
def forward(self, image_embeds: List[torch.FloatTensor]): | |
projected_image_embeds = [] | |
# currently, we accept `image_embeds` as | |
# 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] | |
# 2. list of `n` tensors where `n` is the number of IP-Adapters; each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] | |
if not isinstance(image_embeds, list): | |
deprecation_message = ( | |
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." | |
" Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning." | |
) | |
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) | |
image_embeds = [image_embeds.unsqueeze(1)] | |
if len(image_embeds) != len(self.image_projection_layers): | |
raise ValueError( | |
f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" | |
) | |
for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): | |
batch_size, num_images = image_embed.shape[0], image_embed.shape[1] | |
image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) | |
image_embed = image_projection_layer(image_embed) | |
image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) | |
projected_image_embeds.append(image_embed) | |
return projected_image_embeds | |
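# --- Example (illustrative sketch, not part of the diffusers source) ---
# The wrapper above applies one projection layer per IP-Adapter. Each entry of
# `image_embeds` carries an extra `num_images` axis that is folded into the batch
# before projection and restored afterwards. Sizes are illustrative assumptions.
import torch

proj_layers = [IPAdapterPlusImageProjection(embed_dims=1280, output_dims=2048, num_queries=16)]
multi_proj = MultiIPAdapterImageProjection(proj_layers)

embeds = [torch.randn(2, 1, 257, 1280)]   # one tensor per IP-Adapter: (batch, num_images, seq, dim)
out = multi_proj(embeds)
assert out[0].shape == (2, 1, 16, 2048)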
from ..utils import deprecate | |
from .transformers.t5_film_transformer import ( | |
DecoderLayer, | |
NewGELUActivation, | |
T5DenseGatedActDense, | |
T5FilmDecoder, | |
T5FiLMLayer, | |
T5LayerCrossAttention, | |
T5LayerFFCond, | |
T5LayerNorm, | |
T5LayerSelfAttentionCond, | |
) | |
class T5FilmDecoder(T5FilmDecoder): | |
deprecation_message = "Importing `T5FilmDecoder` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FilmDecoder`, instead." | |
deprecate("T5FilmDecoder", "0.29", deprecation_message) | |
class DecoderLayer(DecoderLayer): | |
deprecation_message = "Importing `DecoderLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import DecoderLayer`, instead." | |
deprecate("DecoderLayer", "0.29", deprecation_message) | |
class T5LayerSelfAttentionCond(T5LayerSelfAttentionCond): | |
deprecation_message = "Importing `T5LayerSelfAttentionCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerSelfAttentionCond`, instead." | |
deprecate("T5LayerSelfAttentionCond", "0.29", deprecation_message) | |
class T5LayerCrossAttention(T5LayerCrossAttention): | |
deprecation_message = "Importing `T5LayerCrossAttention` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerCrossAttention`, instead." | |
deprecate("T5LayerCrossAttention", "0.29", deprecation_message) | |
class T5LayerFFCond(T5LayerFFCond): | |
deprecation_message = "Importing `T5LayerFFCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerFFCond`, instead." | |
deprecate("T5LayerFFCond", "0.29", deprecation_message) | |
class T5DenseGatedActDense(T5DenseGatedActDense): | |
deprecation_message = "Importing `T5DenseGatedActDense` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5DenseGatedActDense`, instead." | |
deprecate("T5DenseGatedActDense", "0.29", deprecation_message) | |
class T5LayerNorm(T5LayerNorm): | |
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerNorm`, instead." | |
deprecate("T5LayerNorm", "0.29", deprecation_message) | |
class NewGELUActivation(NewGELUActivation): | |
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import NewGELUActivation`, instead." | |
deprecate("NewGELUActivation", "0.29", deprecation_message) | |
class T5FiLMLayer(T5FiLMLayer): | |
deprecation_message = "Importing `T5FiLMLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FiLMLayer`, instead." | |
deprecate("T5FiLMLayer", "0.29", deprecation_message) | |
from ...utils import is_flax_available, is_torch_available | |
if is_torch_available(): | |
from .unet_1d import UNet1DModel | |
from .unet_2d import UNet2DModel | |
from .unet_2d_condition import UNet2DConditionModel | |
from .unet_3d_condition import UNet3DConditionModel | |
from .unet_i2vgen_xl import I2VGenXLUNet | |
from .unet_kandinsky3 import Kandinsky3UNet | |
from .unet_motion_model import MotionAdapter, UNetMotionModel | |
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel | |
from .unet_stable_cascade import StableCascadeUNet | |
from .uvit_2d import UVit2DModel | |
if is_flax_available(): | |
from .unet_2d_condition_flax import FlaxUNet2DConditionModel | |
from typing import Any, Dict, Optional, Tuple, Union | |
import torch | |
from torch import nn | |
from ...utils import is_torch_version | |
from ...utils.torch_utils import apply_freeu | |
from ..attention import Attention | |
from ..resnet import ( | |
Downsample2D, | |
ResnetBlock2D, | |
SpatioTemporalResBlock, | |
TemporalConvLayer, | |
Upsample2D, | |
) | |
from ..transformers.dual_transformer_2d import DualTransformer2DModel | |
from ..transformers.transformer_2d import Transformer2DModel | |
from ..transformers.transformer_temporal import ( | |
TransformerSpatioTemporalModel, | |
TransformerTemporalModel, | |
) | |
def get_down_block( | |
down_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
add_downsample: bool, | |
resnet_eps: float, | |
resnet_act_fn: str, | |
num_attention_heads: int, | |
resnet_groups: Optional[int] = None, | |
cross_attention_dim: Optional[int] = None, | |
downsample_padding: Optional[int] = None, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = True, | |
only_cross_attention: bool = False, | |
upcast_attention: bool = False, | |
resnet_time_scale_shift: str = "default", | |
temporal_num_attention_heads: int = 8, | |
temporal_max_seq_length: int = 32, | |
transformer_layers_per_block: int = 1, | |
) -> Union[ | |
"DownBlock3D", | |
"CrossAttnDownBlock3D", | |
"DownBlockMotion", | |
"CrossAttnDownBlockMotion", | |
"DownBlockSpatioTemporal", | |
"CrossAttnDownBlockSpatioTemporal", | |
]: | |
if down_block_type == "DownBlock3D": | |
return DownBlock3D( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
downsample_padding=downsample_padding, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
) | |
elif down_block_type == "CrossAttnDownBlock3D": | |
if cross_attention_dim is None: | |
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") | |
return CrossAttnDownBlock3D( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
downsample_padding=downsample_padding, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads, | |
dual_cross_attention=dual_cross_attention, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
) | |
if down_block_type == "DownBlockMotion": | |
return DownBlockMotion( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
downsample_padding=downsample_padding, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
temporal_num_attention_heads=temporal_num_attention_heads, | |
temporal_max_seq_length=temporal_max_seq_length, | |
) | |
elif down_block_type == "CrossAttnDownBlockMotion": | |
if cross_attention_dim is None: | |
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") | |
return CrossAttnDownBlockMotion( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
downsample_padding=downsample_padding, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads, | |
dual_cross_attention=dual_cross_attention, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
temporal_num_attention_heads=temporal_num_attention_heads, | |
temporal_max_seq_length=temporal_max_seq_length, | |
) | |
elif down_block_type == "DownBlockSpatioTemporal": | |
# added for Stable Video Diffusion (SVD) | |
return DownBlockSpatioTemporal( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
) | |
elif down_block_type == "CrossAttnDownBlockSpatioTemporal": | |
# added for Stable Video Diffusion (SVD) | |
if cross_attention_dim is None: | |
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") | |
return CrossAttnDownBlockSpatioTemporal( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
num_layers=num_layers, | |
transformer_layers_per_block=transformer_layers_per_block, | |
add_downsample=add_downsample, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads, | |
) | |
raise ValueError(f"{down_block_type} does not exist.") | |
def get_up_block( | |
up_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
out_channels: int, | |
prev_output_channel: int, | |
temb_channels: int, | |
add_upsample: bool, | |
resnet_eps: float, | |
resnet_act_fn: str, | |
num_attention_heads: int, | |
resolution_idx: Optional[int] = None, | |
resnet_groups: Optional[int] = None, | |
cross_attention_dim: Optional[int] = None, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = True, | |
only_cross_attention: bool = False, | |
upcast_attention: bool = False, | |
resnet_time_scale_shift: str = "default", | |
temporal_num_attention_heads: int = 8, | |
temporal_cross_attention_dim: Optional[int] = None, | |
temporal_max_seq_length: int = 32, | |
transformer_layers_per_block: int = 1, | |
dropout: float = 0.0, | |
) -> Union[ | |
"UpBlock3D", | |
"CrossAttnUpBlock3D", | |
"UpBlockMotion", | |
"CrossAttnUpBlockMotion", | |
"UpBlockSpatioTemporal", | |
"CrossAttnUpBlockSpatioTemporal", | |
]: | |
if up_block_type == "UpBlock3D": | |
return UpBlock3D( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
prev_output_channel=prev_output_channel, | |
temb_channels=temb_channels, | |
add_upsample=add_upsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
resolution_idx=resolution_idx, | |
) | |
elif up_block_type == "CrossAttnUpBlock3D": | |
if cross_attention_dim is None: | |
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") | |
return CrossAttnUpBlock3D( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
prev_output_channel=prev_output_channel, | |
temb_channels=temb_channels, | |
add_upsample=add_upsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads, | |
dual_cross_attention=dual_cross_attention, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
resolution_idx=resolution_idx, | |
) | |
if up_block_type == "UpBlockMotion": | |
return UpBlockMotion( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
prev_output_channel=prev_output_channel, | |
temb_channels=temb_channels, | |
add_upsample=add_upsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
resolution_idx=resolution_idx, | |
temporal_num_attention_heads=temporal_num_attention_heads, | |
temporal_max_seq_length=temporal_max_seq_length, | |
) | |
elif up_block_type == "CrossAttnUpBlockMotion": | |
if cross_attention_dim is None: | |
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") | |
return CrossAttnUpBlockMotion( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
prev_output_channel=prev_output_channel, | |
temb_channels=temb_channels, | |
add_upsample=add_upsample, | |
resnet_eps=resnet_eps, | |
resnet_act_fn=resnet_act_fn, | |
resnet_groups=resnet_groups, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads, | |
dual_cross_attention=dual_cross_attention, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
resolution_idx=resolution_idx, | |
temporal_num_attention_heads=temporal_num_attention_heads, | |
temporal_max_seq_length=temporal_max_seq_length, | |
) | |
elif up_block_type == "UpBlockSpatioTemporal": | |
# added for Stable Video Diffusion (SVD) | |
return UpBlockSpatioTemporal( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
prev_output_channel=prev_output_channel, | |
temb_channels=temb_channels, | |
resolution_idx=resolution_idx, | |
add_upsample=add_upsample, | |
) | |
elif up_block_type == "CrossAttnUpBlockSpatioTemporal": | |
# added for Stable Video Diffusion (SVD) | |
if cross_attention_dim is None: | |
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") | |
return CrossAttnUpBlockSpatioTemporal( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
prev_output_channel=prev_output_channel, | |
temb_channels=temb_channels, | |
num_layers=num_layers, | |
transformer_layers_per_block=transformer_layers_per_block, | |
add_upsample=add_upsample, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads, | |
resolution_idx=resolution_idx, | |
) | |
raise ValueError(f"{up_block_type} does not exist.") | |
class UNetMidBlock3DCrossAttn(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
num_attention_heads: int = 1, | |
output_scale_factor: float = 1.0, | |
cross_attention_dim: int = 1280, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = True, | |
upcast_attention: bool = False, | |
): | |
super().__init__() | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) | |
# there is always at least one resnet | |
resnets = [ | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
] | |
temp_convs = [ | |
TemporalConvLayer( | |
in_channels, | |
in_channels, | |
dropout=0.1, | |
norm_num_groups=resnet_groups, | |
) | |
] | |
attentions = [] | |
temp_attentions = [] | |
for _ in range(num_layers): | |
attentions.append( | |
Transformer2DModel( | |
in_channels // num_attention_heads, | |
num_attention_heads, | |
in_channels=in_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
use_linear_projection=use_linear_projection, | |
upcast_attention=upcast_attention, | |
) | |
) | |
temp_attentions.append( | |
TransformerTemporalModel( | |
in_channels // num_attention_heads, | |
num_attention_heads, | |
in_channels=in_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
temp_convs.append( | |
TemporalConvLayer( | |
in_channels, | |
in_channels, | |
dropout=0.1, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
self.temp_convs = nn.ModuleList(temp_convs) | |
self.attentions = nn.ModuleList(attentions) | |
self.temp_attentions = nn.ModuleList(temp_attentions) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
num_frames: int = 1, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
) -> torch.FloatTensor: | |
hidden_states = self.resnets[0](hidden_states, temb) | |
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) | |
for attn, temp_attn, resnet, temp_conv in zip( | |
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] | |
): | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
return_dict=False, | |
)[0] | |
hidden_states = temp_attn( | |
hidden_states, | |
num_frames=num_frames, | |
cross_attention_kwargs=cross_attention_kwargs, | |
return_dict=False, | |
)[0] | |
hidden_states = resnet(hidden_states, temb) | |
hidden_states = temp_conv(hidden_states, num_frames=num_frames) | |
return hidden_states | |
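# --- Example (illustrative sketch, not part of the diffusers source) ---
# Intention of the block above: the 3D mid block interleaves spatial resnets, spatial
# cross-attention, temporal convolutions and temporal self-attention. Frames are folded
# into the batch axis, so a 2-sample, 8-frame batch enters as 16 "images". All sizes
# below are assumptions for illustration.
import torch

mid = UNetMidBlock3DCrossAttn(
    in_channels=1280, temb_channels=1280, cross_attention_dim=1024, num_attention_heads=8
)
x = torch.randn(2 * 8, 1280, 8, 8)    # (batch * num_frames, C, H, W)
temb = torch.randn(2 * 8, 1280)
ctx = torch.randn(2 * 8, 77, 1024)    # text-encoder states, repeated per frame
out = mid(x, temb=temb, encoder_hidden_states=ctx, num_frames=8)
assert out.shape == x.shape           # the mid block preserves spatial resolution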
class CrossAttnDownBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
num_attention_heads: int = 1, | |
cross_attention_dim: int = 1280, | |
output_scale_factor: float = 1.0, | |
downsample_padding: int = 1, | |
add_downsample: bool = True, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = False, | |
only_cross_attention: bool = False, | |
upcast_attention: bool = False, | |
): | |
super().__init__() | |
resnets = [] | |
attentions = [] | |
temp_attentions = [] | |
temp_convs = [] | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
temp_convs.append( | |
TemporalConvLayer( | |
out_channels, | |
out_channels, | |
dropout=0.1, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
attentions.append( | |
Transformer2DModel( | |
out_channels // num_attention_heads, | |
num_attention_heads, | |
in_channels=out_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
) | |
) | |
temp_attentions.append( | |
TransformerTemporalModel( | |
out_channels // num_attention_heads, | |
num_attention_heads, | |
in_channels=out_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
self.temp_convs = nn.ModuleList(temp_convs) | |
self.attentions = nn.ModuleList(attentions) | |
self.temp_attentions = nn.ModuleList(temp_attentions) | |
if add_downsample: | |
self.downsamplers = nn.ModuleList( | |
[ | |
Downsample2D( | |
out_channels, | |
use_conv=True, | |
out_channels=out_channels, | |
padding=downsample_padding, | |
name="op", | |
) | |
] | |
) | |
else: | |
self.downsamplers = None | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
num_frames: int = 1, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: | |
# TODO(Patrick, William) - attention mask is not used | |
output_states = () | |
for resnet, temp_conv, attn, temp_attn in zip( | |
self.resnets, self.temp_convs, self.attentions, self.temp_attentions | |
): | |
hidden_states = resnet(hidden_states, temb) | |
hidden_states = temp_conv(hidden_states, num_frames=num_frames) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
return_dict=False, | |
)[0] | |
hidden_states = temp_attn( | |
hidden_states, | |
num_frames=num_frames, | |
cross_attention_kwargs=cross_attention_kwargs, | |
return_dict=False, | |
)[0] | |
output_states += (hidden_states,) | |
if self.downsamplers is not None: | |
for downsampler in self.downsamplers: | |
hidden_states = downsampler(hidden_states) | |
output_states += (hidden_states,) | |
return hidden_states, output_states | |
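# --- Example (illustrative sketch, not part of the diffusers source) ---
# The 3D cross-attention down block returns both the (optionally downsampled) states and
# the per-layer residuals used later by the corresponding up block. Sizes are assumptions.
import torch

down = CrossAttnDownBlock3D(
    in_channels=320, out_channels=320, temb_channels=1280,
    num_layers=2, cross_attention_dim=1024, num_attention_heads=8,
)
x = torch.randn(2 * 8, 320, 32, 32)
temb = torch.randn(2 * 8, 1280)
ctx = torch.randn(2 * 8, 77, 1024)
x, residuals = down(x, temb=temb, encoder_hidden_states=ctx, num_frames=8)
assert len(residuals) == 3     # 2 layer outputs + 1 downsampler output
assert x.shape[-1] == 16       # spatial size halved by the downsampler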
class DownBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
output_scale_factor: float = 1.0, | |
add_downsample: bool = True, | |
downsample_padding: int = 1, | |
): | |
super().__init__() | |
resnets = [] | |
temp_convs = [] | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
temp_convs.append( | |
TemporalConvLayer( | |
out_channels, | |
out_channels, | |
dropout=0.1, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
self.temp_convs = nn.ModuleList(temp_convs) | |
if add_downsample: | |
self.downsamplers = nn.ModuleList( | |
[ | |
Downsample2D( | |
out_channels, | |
use_conv=True, | |
out_channels=out_channels, | |
padding=downsample_padding, | |
name="op", | |
) | |
] | |
) | |
else: | |
self.downsamplers = None | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
num_frames: int = 1, | |
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: | |
output_states = () | |
for resnet, temp_conv in zip(self.resnets, self.temp_convs): | |
hidden_states = resnet(hidden_states, temb) | |
hidden_states = temp_conv(hidden_states, num_frames=num_frames) | |
output_states += (hidden_states,) | |
if self.downsamplers is not None: | |
for downsampler in self.downsamplers: | |
hidden_states = downsampler(hidden_states) | |
output_states += (hidden_states,) | |
return hidden_states, output_states | |
class CrossAttnUpBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
prev_output_channel: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
num_attention_heads: int = 1, | |
cross_attention_dim: int = 1280, | |
output_scale_factor: float = 1.0, | |
add_upsample: bool = True, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = False, | |
only_cross_attention: bool = False, | |
upcast_attention: bool = False, | |
resolution_idx: Optional[int] = None, | |
): | |
super().__init__() | |
resnets = [] | |
temp_convs = [] | |
attentions = [] | |
temp_attentions = [] | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
for i in range(num_layers): | |
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels | |
resnet_in_channels = prev_output_channel if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
temp_convs.append( | |
TemporalConvLayer( | |
out_channels, | |
out_channels, | |
dropout=0.1, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
attentions.append( | |
Transformer2DModel( | |
out_channels // num_attention_heads, | |
num_attention_heads, | |
in_channels=out_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
) | |
) | |
temp_attentions.append( | |
TransformerTemporalModel( | |
out_channels // num_attention_heads, | |
num_attention_heads, | |
in_channels=out_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
self.temp_convs = nn.ModuleList(temp_convs) | |
self.attentions = nn.ModuleList(attentions) | |
self.temp_attentions = nn.ModuleList(temp_attentions) | |
if add_upsample: | |
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) | |
else: | |
self.upsamplers = None | |
self.gradient_checkpointing = False | |
self.resolution_idx = resolution_idx | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
upsample_size: Optional[int] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
num_frames: int = 1, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
) -> torch.FloatTensor: | |
is_freeu_enabled = ( | |
getattr(self, "s1", None) | |
and getattr(self, "s2", None) | |
and getattr(self, "b1", None) | |
and getattr(self, "b2", None) | |
) | |
# TODO(Patrick, William) - attention mask is not used | |
for resnet, temp_conv, attn, temp_attn in zip( | |
self.resnets, self.temp_convs, self.attentions, self.temp_attentions | |
): | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
# FreeU: Only operate on the first two stages | |
if is_freeu_enabled: | |
hidden_states, res_hidden_states = apply_freeu( | |
self.resolution_idx, | |
hidden_states, | |
res_hidden_states, | |
s1=self.s1, | |
s2=self.s2, | |
b1=self.b1, | |
b2=self.b2, | |
) | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
hidden_states = resnet(hidden_states, temb) | |
hidden_states = temp_conv(hidden_states, num_frames=num_frames) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
return_dict=False, | |
)[0] | |
hidden_states = temp_attn( | |
hidden_states, | |
num_frames=num_frames, | |
cross_attention_kwargs=cross_attention_kwargs, | |
return_dict=False, | |
)[0] | |
if self.upsamplers is not None: | |
for upsampler in self.upsamplers: | |
hidden_states = upsampler(hidden_states, upsample_size) | |
return hidden_states | |
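# --- Example (illustrative sketch, not part of the diffusers source) ---
# On the decoder side the residual tuple saved by the down blocks is consumed
# last-in-first-out and concatenated on the channel axis before each resnet. Choosing
# equal channel counts everywhere keeps this sketch simple; real UNets vary per stage.
import torch

up = CrossAttnUpBlock3D(
    in_channels=320, out_channels=320, prev_output_channel=320, temb_channels=1280,
    num_layers=3, cross_attention_dim=1024, num_attention_heads=8, add_upsample=True,
)
x = torch.randn(2 * 8, 320, 16, 16)
res = tuple(torch.randn(2 * 8, 320, 16, 16) for _ in range(3))   # one residual per layer
temb = torch.randn(2 * 8, 1280)
ctx = torch.randn(2 * 8, 77, 1024)
out = up(x, res, temb=temb, encoder_hidden_states=ctx, num_frames=8)
assert out.shape == (2 * 8, 320, 32, 32)   # doubled spatially by the final Upsample2D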
class UpBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
prev_output_channel: int, | |
out_channels: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
output_scale_factor: float = 1.0, | |
add_upsample: bool = True, | |
resolution_idx: Optional[int] = None, | |
): | |
super().__init__() | |
resnets = [] | |
temp_convs = [] | |
for i in range(num_layers): | |
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels | |
resnet_in_channels = prev_output_channel if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
temp_convs.append( | |
TemporalConvLayer( | |
out_channels, | |
out_channels, | |
dropout=0.1, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
self.temp_convs = nn.ModuleList(temp_convs) | |
if add_upsample: | |
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) | |
else: | |
self.upsamplers = None | |
self.gradient_checkpointing = False | |
self.resolution_idx = resolution_idx | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
upsample_size: Optional[int] = None, | |
num_frames: int = 1, | |
) -> torch.FloatTensor: | |
is_freeu_enabled = ( | |
getattr(self, "s1", None) | |
and getattr(self, "s2", None) | |
and getattr(self, "b1", None) | |
and getattr(self, "b2", None) | |
) | |
for resnet, temp_conv in zip(self.resnets, self.temp_convs): | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
# FreeU: Only operate on the first two stages | |
if is_freeu_enabled: | |
hidden_states, res_hidden_states = apply_freeu( | |
self.resolution_idx, | |
hidden_states, | |
res_hidden_states, | |
s1=self.s1, | |
s2=self.s2, | |
b1=self.b1, | |
b2=self.b2, | |
) | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
hidden_states = resnet(hidden_states, temb) | |
hidden_states = temp_conv(hidden_states, num_frames=num_frames) | |
if self.upsamplers is not None: | |
for upsampler in self.upsamplers: | |
hidden_states = upsampler(hidden_states, upsample_size) | |
return hidden_states | |
class DownBlockMotion(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
output_scale_factor: float = 1.0, | |
add_downsample: bool = True, | |
downsample_padding: int = 1, | |
temporal_num_attention_heads: int = 1, | |
temporal_cross_attention_dim: Optional[int] = None, | |
temporal_max_seq_length: int = 32, | |
): | |
super().__init__() | |
resnets = [] | |
motion_modules = [] | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
motion_modules.append( | |
TransformerTemporalModel( | |
num_attention_heads=temporal_num_attention_heads, | |
in_channels=out_channels, | |
norm_num_groups=resnet_groups, | |
cross_attention_dim=temporal_cross_attention_dim, | |
attention_bias=False, | |
activation_fn="geglu", | |
positional_embeddings="sinusoidal", | |
num_positional_embeddings=temporal_max_seq_length, | |
attention_head_dim=out_channels // temporal_num_attention_heads, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
self.motion_modules = nn.ModuleList(motion_modules) | |
if add_downsample: | |
self.downsamplers = nn.ModuleList( | |
[ | |
Downsample2D( | |
out_channels, | |
use_conv=True, | |
out_channels=out_channels, | |
padding=downsample_padding, | |
name="op", | |
) | |
] | |
) | |
else: | |
self.downsamplers = None | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
scale: float = 1.0, | |
num_frames: int = 1, | |
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: | |
output_states = () | |
blocks = zip(self.resnets, self.motion_modules) | |
for resnet, motion_module in blocks: | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module): | |
def custom_forward(*inputs): | |
return module(*inputs) | |
return custom_forward | |
if is_torch_version(">=", "1.11.0"): | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
use_reentrant=False, | |
) | |
else: | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), hidden_states, temb, scale | |
) | |
else: | |
hidden_states = resnet(hidden_states, temb, scale=scale) | |
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] | |
output_states = output_states + (hidden_states,) | |
if self.downsamplers is not None: | |
for downsampler in self.downsamplers: | |
hidden_states = downsampler(hidden_states, scale=scale) | |
output_states = output_states + (hidden_states,) | |
return hidden_states, output_states | |
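# --- Example (illustrative sketch, not part of the diffusers source) ---
# AnimateDiff-style block: each spatial resnet is followed by a temporal transformer
# ("motion module") that attends across the frame axis. `num_frames` must not exceed
# `temporal_max_seq_length` (32 by default). Sizes below are assumptions.
import torch

block = DownBlockMotion(in_channels=320, out_channels=320, temb_channels=1280, num_layers=2)
x = torch.randn(2 * 16, 320, 32, 32)   # (batch * num_frames, C, H, W)
temb = torch.randn(2 * 16, 1280)
x, residuals = block(x, temb=temb, num_frames=16)
assert x.shape == (2 * 16, 320, 16, 16)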
class CrossAttnDownBlockMotion(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
transformer_layers_per_block: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
num_attention_heads: int = 1, | |
cross_attention_dim: int = 1280, | |
output_scale_factor: float = 1.0, | |
downsample_padding: int = 1, | |
add_downsample: bool = True, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = False, | |
only_cross_attention: bool = False, | |
upcast_attention: bool = False, | |
attention_type: str = "default", | |
temporal_cross_attention_dim: Optional[int] = None, | |
temporal_num_attention_heads: int = 8, | |
temporal_max_seq_length: int = 32, | |
): | |
super().__init__() | |
resnets = [] | |
attentions = [] | |
motion_modules = [] | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
if not dual_cross_attention: | |
attentions.append( | |
Transformer2DModel( | |
num_attention_heads, | |
out_channels // num_attention_heads, | |
in_channels=out_channels, | |
num_layers=transformer_layers_per_block, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
attention_type=attention_type, | |
) | |
) | |
else: | |
attentions.append( | |
DualTransformer2DModel( | |
num_attention_heads, | |
out_channels // num_attention_heads, | |
in_channels=out_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
motion_modules.append( | |
TransformerTemporalModel( | |
num_attention_heads=temporal_num_attention_heads, | |
in_channels=out_channels, | |
norm_num_groups=resnet_groups, | |
cross_attention_dim=temporal_cross_attention_dim, | |
attention_bias=False, | |
activation_fn="geglu", | |
positional_embeddings="sinusoidal", | |
num_positional_embeddings=temporal_max_seq_length, | |
attention_head_dim=out_channels // temporal_num_attention_heads, | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
self.motion_modules = nn.ModuleList(motion_modules) | |
if add_downsample: | |
self.downsamplers = nn.ModuleList( | |
[ | |
Downsample2D( | |
out_channels, | |
use_conv=True, | |
out_channels=out_channels, | |
padding=downsample_padding, | |
name="op", | |
) | |
] | |
) | |
else: | |
self.downsamplers = None | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
num_frames: int = 1, | |
encoder_attention_mask: Optional[torch.FloatTensor] = None, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
additional_residuals: Optional[torch.FloatTensor] = None, | |
): | |
output_states = () | |
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 | |
blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) | |
for i, (resnet, attn, motion_module) in enumerate(blocks): | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module, return_dict=None): | |
def custom_forward(*inputs): | |
if return_dict is not None: | |
return module(*inputs, return_dict=return_dict) | |
else: | |
return module(*inputs) | |
return custom_forward | |
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
**ckpt_kwargs, | |
) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
return_dict=False, | |
)[0] | |
else: | |
hidden_states = resnet(hidden_states, temb, scale=lora_scale) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
return_dict=False, | |
)[0] | |
hidden_states = motion_module( | |
hidden_states, | |
num_frames=num_frames, | |
)[0] | |
# apply additional residuals to the output of the last pair of resnet and attention blocks | |
if i == len(blocks) - 1 and additional_residuals is not None: | |
hidden_states = hidden_states + additional_residuals | |
output_states = output_states + (hidden_states,) | |
if self.downsamplers is not None: | |
for downsampler in self.downsamplers: | |
hidden_states = downsampler(hidden_states, scale=lora_scale) | |
output_states = output_states + (hidden_states,) | |
return hidden_states, output_states | |
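# --- Example (illustrative sketch, not part of the diffusers source) ---
# Cross-attention variant of the motion down block: per layer it runs a spatial
# resnet, a text-conditioned Transformer2DModel, and then the temporal motion module.
# The 768-dim text context mimics a CLIP text encoder and is an assumption here.
import torch

block = CrossAttnDownBlockMotion(
    in_channels=320, out_channels=320, temb_channels=1280,
    num_layers=2, cross_attention_dim=768, num_attention_heads=8,
)
x = torch.randn(2 * 16, 320, 32, 32)
temb = torch.randn(2 * 16, 1280)
ctx = torch.randn(2 * 16, 77, 768)
x, residuals = block(x, temb=temb, encoder_hidden_states=ctx, num_frames=16)
assert len(residuals) == 3    # 2 layer outputs + 1 downsampler output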
class CrossAttnUpBlockMotion(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
prev_output_channel: int, | |
temb_channels: int, | |
resolution_idx: Optional[int] = None, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
transformer_layers_per_block: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
num_attention_heads: int = 1, | |
cross_attention_dim: int = 1280, | |
output_scale_factor: float = 1.0, | |
add_upsample: bool = True, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = False, | |
only_cross_attention: bool = False, | |
upcast_attention: bool = False, | |
attention_type: str = "default", | |
temporal_cross_attention_dim: Optional[int] = None, | |
temporal_num_attention_heads: int = 8, | |
temporal_max_seq_length: int = 32, | |
): | |
super().__init__() | |
resnets = [] | |
attentions = [] | |
motion_modules = [] | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
for i in range(num_layers): | |
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels | |
resnet_in_channels = prev_output_channel if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
if not dual_cross_attention: | |
attentions.append( | |
Transformer2DModel( | |
num_attention_heads, | |
out_channels // num_attention_heads, | |
in_channels=out_channels, | |
num_layers=transformer_layers_per_block, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention, | |
upcast_attention=upcast_attention, | |
attention_type=attention_type, | |
) | |
) | |
else: | |
attentions.append( | |
DualTransformer2DModel( | |
num_attention_heads, | |
out_channels // num_attention_heads, | |
in_channels=out_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
motion_modules.append( | |
TransformerTemporalModel( | |
num_attention_heads=temporal_num_attention_heads, | |
in_channels=out_channels, | |
norm_num_groups=resnet_groups, | |
cross_attention_dim=temporal_cross_attention_dim, | |
attention_bias=False, | |
activation_fn="geglu", | |
positional_embeddings="sinusoidal", | |
num_positional_embeddings=temporal_max_seq_length, | |
attention_head_dim=out_channels // temporal_num_attention_heads, | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
self.motion_modules = nn.ModuleList(motion_modules) | |
if add_upsample: | |
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) | |
else: | |
self.upsamplers = None | |
self.gradient_checkpointing = False | |
self.resolution_idx = resolution_idx | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
upsample_size: Optional[int] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
encoder_attention_mask: Optional[torch.FloatTensor] = None, | |
num_frames: int = 1, | |
) -> torch.FloatTensor: | |
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 | |
is_freeu_enabled = ( | |
getattr(self, "s1", None) | |
and getattr(self, "s2", None) | |
and getattr(self, "b1", None) | |
and getattr(self, "b2", None) | |
) | |
blocks = zip(self.resnets, self.attentions, self.motion_modules) | |
for resnet, attn, motion_module in blocks: | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
# FreeU: Only operate on the first two stages | |
if is_freeu_enabled: | |
hidden_states, res_hidden_states = apply_freeu( | |
self.resolution_idx, | |
hidden_states, | |
res_hidden_states, | |
s1=self.s1, | |
s2=self.s2, | |
b1=self.b1, | |
b2=self.b2, | |
) | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module, return_dict=None): | |
def custom_forward(*inputs): | |
if return_dict is not None: | |
return module(*inputs, return_dict=return_dict) | |
else: | |
return module(*inputs) | |
return custom_forward | |
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
**ckpt_kwargs, | |
) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
return_dict=False, | |
)[0] | |
else: | |
hidden_states = resnet(hidden_states, temb, scale=lora_scale) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
return_dict=False, | |
)[0] | |
hidden_states = motion_module( | |
hidden_states, | |
num_frames=num_frames, | |
)[0] | |
if self.upsamplers is not None: | |
for upsampler in self.upsamplers: | |
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) | |
return hidden_states | |
class UpBlockMotion(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
prev_output_channel: int, | |
out_channels: int, | |
temb_channels: int, | |
resolution_idx: Optional[int] = None, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
output_scale_factor: float = 1.0, | |
add_upsample: bool = True, | |
temporal_norm_num_groups: int = 32, | |
temporal_cross_attention_dim: Optional[int] = None, | |
temporal_num_attention_heads: int = 8, | |
temporal_max_seq_length: int = 32, | |
): | |
super().__init__() | |
resnets = [] | |
motion_modules = [] | |
for i in range(num_layers): | |
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels | |
resnet_in_channels = prev_output_channel if i == 0 else out_channels | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
motion_modules.append( | |
TransformerTemporalModel( | |
num_attention_heads=temporal_num_attention_heads, | |
in_channels=out_channels, | |
norm_num_groups=temporal_norm_num_groups, | |
cross_attention_dim=temporal_cross_attention_dim, | |
attention_bias=False, | |
activation_fn="geglu", | |
positional_embeddings="sinusoidal", | |
num_positional_embeddings=temporal_max_seq_length, | |
attention_head_dim=out_channels // temporal_num_attention_heads, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
self.motion_modules = nn.ModuleList(motion_modules) | |
if add_upsample: | |
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) | |
else: | |
self.upsamplers = None | |
self.gradient_checkpointing = False | |
self.resolution_idx = resolution_idx | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
upsample_size: Optional[int] = None, | |
scale: float = 1.0, | |
num_frames: int = 1, | |
) -> torch.FloatTensor: | |
is_freeu_enabled = ( | |
getattr(self, "s1", None) | |
and getattr(self, "s2", None) | |
and getattr(self, "b1", None) | |
and getattr(self, "b2", None) | |
) | |
blocks = zip(self.resnets, self.motion_modules) | |
for resnet, motion_module in blocks: | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
# FreeU: Only operate on the first two stages | |
if is_freeu_enabled: | |
hidden_states, res_hidden_states = apply_freeu( | |
self.resolution_idx, | |
hidden_states, | |
res_hidden_states, | |
s1=self.s1, | |
s2=self.s2, | |
b1=self.b1, | |
b2=self.b2, | |
) | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module): | |
def custom_forward(*inputs): | |
return module(*inputs) | |
return custom_forward | |
if is_torch_version(">=", "1.11.0"): | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
use_reentrant=False, | |
) | |
else: | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), hidden_states, temb | |
) | |
else: | |
hidden_states = resnet(hidden_states, temb, scale=scale) | |
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] | |
if self.upsamplers is not None: | |
for upsampler in self.upsamplers: | |
hidden_states = upsampler(hidden_states, upsample_size, scale=scale) | |
return hidden_states | |
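# --- Hedged usage sketch (illustrative, not part of the original source) ---
# A minimal smoke test for UpBlockMotion as defined above, assuming the building blocks
# imported earlier in this file (ResnetBlock2D, TransformerTemporalModel, Upsample2D)
# behave as in recent diffusers releases. Channel sizes below are illustrative only.
def _example_up_block_motion():
    import torch

    batch, frames, height, width = 2, 8, 16, 16
    block = UpBlockMotion(
        in_channels=320,             # channels of the matching encoder skip connection
        prev_output_channel=640,     # channels coming from the previous up block
        out_channels=320,
        temb_channels=1280,
        num_layers=1,
        temporal_max_seq_length=frames,
    )
    # Motion UNets fold the frame axis into the batch dimension: (B * F, C, H, W).
    hidden = torch.randn(batch * frames, 640, height, width)
    skip = (torch.randn(batch * frames, 320, height, width),)
    temb = torch.randn(batch * frames, 1280)
    out = block(hidden, skip, temb=temb, num_frames=frames)
    # add_upsample=True (the default) doubles the spatial resolution.
    assert out.shape == (batch * frames, 320, height * 2, width * 2)
    return out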
class UNetMidBlockCrossAttnMotion(nn.Module): | |
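# Intention of this block: mid block for motion UNets. Spatial Transformer2DModel layers
# cross-attend to the encoder hidden states, ResnetBlock2D layers refine features, and
# TransformerTemporalModel motion modules attend across frames at constant resolution.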
def __init__( | |
self, | |
in_channels: int, | |
temb_channels: int, | |
dropout: float = 0.0, | |
num_layers: int = 1, | |
transformer_layers_per_block: int = 1, | |
resnet_eps: float = 1e-6, | |
resnet_time_scale_shift: str = "default", | |
resnet_act_fn: str = "swish", | |
resnet_groups: int = 32, | |
resnet_pre_norm: bool = True, | |
num_attention_heads: int = 1, | |
output_scale_factor: float = 1.0, | |
cross_attention_dim: int = 1280, | |
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default", | |
temporal_num_attention_heads: int = 1, | |
temporal_cross_attention_dim: Optional[int] = None, | |
temporal_max_seq_length: int = 32, | |
): | |
super().__init__() | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) | |
# there is always at least one resnet | |
resnets = [ | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
] | |
attentions = [] | |
motion_modules = [] | |
for _ in range(num_layers): | |
if not dual_cross_attention: | |
attentions.append( | |
Transformer2DModel( | |
num_attention_heads, | |
in_channels // num_attention_heads, | |
in_channels=in_channels, | |
num_layers=transformer_layers_per_block, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
use_linear_projection=use_linear_projection, | |
upcast_attention=upcast_attention, | |
attention_type=attention_type, | |
) | |
) | |
else: | |
attentions.append( | |
DualTransformer2DModel( | |
num_attention_heads, | |
in_channels // num_attention_heads, | |
in_channels=in_channels, | |
num_layers=1, | |
cross_attention_dim=cross_attention_dim, | |
norm_num_groups=resnet_groups, | |
) | |
) | |
resnets.append( | |
ResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
groups=resnet_groups, | |
dropout=dropout, | |
time_embedding_norm=resnet_time_scale_shift, | |
non_linearity=resnet_act_fn, | |
output_scale_factor=output_scale_factor, | |
pre_norm=resnet_pre_norm, | |
) | |
) | |
motion_modules.append( | |
TransformerTemporalModel( | |
num_attention_heads=temporal_num_attention_heads, | |
attention_head_dim=in_channels // temporal_num_attention_heads, | |
in_channels=in_channels, | |
norm_num_groups=resnet_groups, | |
cross_attention_dim=temporal_cross_attention_dim, | |
attention_bias=False, | |
positional_embeddings="sinusoidal", | |
num_positional_embeddings=temporal_max_seq_length, | |
activation_fn="geglu", | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
self.motion_modules = nn.ModuleList(motion_modules) | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
encoder_attention_mask: Optional[torch.FloatTensor] = None, | |
num_frames: int = 1, | |
) -> torch.FloatTensor: | |
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 | |
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) | |
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) | |
for attn, resnet, motion_module in blocks: | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module, return_dict=None): | |
def custom_forward(*inputs): | |
if return_dict is not None: | |
return module(*inputs, return_dict=return_dict) | |
else: | |
return module(*inputs) | |
return custom_forward | |
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
return_dict=False, | |
)[0] | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(motion_module), | |
hidden_states, | |
temb, | |
**ckpt_kwargs, | |
) | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
**ckpt_kwargs, | |
) | |
else: | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
return_dict=False, | |
)[0] | |
hidden_states = motion_module( | |
hidden_states, | |
num_frames=num_frames, | |
)[0] | |
hidden_states = resnet(hidden_states, temb, scale=lora_scale) | |
return hidden_states | |
class MidBlockTemporalDecoder(nn.Module): | |
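# Intention of this block: bottleneck of a temporal VAE decoder (e.g. the video decoder
# used with Stable Video Diffusion). SpatioTemporalResBlocks blend per-frame and
# cross-frame features, interleaved with spatial self-Attention layers with residual connections.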
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
attention_head_dim: int = 512, | |
num_layers: int = 1, | |
upcast_attention: bool = False, | |
): | |
super().__init__() | |
resnets = [] | |
attentions = [] | |
for i in range(num_layers): | |
input_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
SpatioTemporalResBlock( | |
in_channels=input_channels, | |
out_channels=out_channels, | |
temb_channels=None, | |
eps=1e-6, | |
temporal_eps=1e-5, | |
merge_factor=0.0, | |
merge_strategy="learned", | |
switch_spatial_to_temporal_mix=True, | |
) | |
) | |
attentions.append( | |
Attention( | |
query_dim=in_channels, | |
heads=in_channels // attention_head_dim, | |
dim_head=attention_head_dim, | |
eps=1e-6, | |
upcast_attention=upcast_attention, | |
norm_num_groups=32, | |
bias=True, | |
residual_connection=True, | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
image_only_indicator: torch.FloatTensor, | |
): | |
hidden_states = self.resnets[0]( | |
hidden_states, | |
image_only_indicator=image_only_indicator, | |
) | |
for resnet, attn in zip(self.resnets[1:], self.attentions): | |
hidden_states = attn(hidden_states) | |
hidden_states = resnet( | |
hidden_states, | |
image_only_indicator=image_only_indicator, | |
) | |
return hidden_states | |
class UpBlockTemporalDecoder(nn.Module): | |
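# Intention of this block: upsampling stage of the temporal VAE decoder. A stack of
# SpatioTemporalResBlocks is followed by an optional Upsample2D that doubles the spatial
# resolution of every decoded frame.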
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int = 1, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
resnets = [] | |
for i in range(num_layers): | |
input_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
SpatioTemporalResBlock( | |
in_channels=input_channels, | |
out_channels=out_channels, | |
temb_channels=None, | |
eps=1e-6, | |
temporal_eps=1e-5, | |
merge_factor=0.0, | |
merge_strategy="learned", | |
switch_spatial_to_temporal_mix=True, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
if add_upsample: | |
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) | |
else: | |
self.upsamplers = None | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
image_only_indicator: torch.FloatTensor, | |
) -> torch.FloatTensor: | |
for resnet in self.resnets: | |
hidden_states = resnet( | |
hidden_states, | |
image_only_indicator=image_only_indicator, | |
) | |
if self.upsamplers is not None: | |
for upsampler in self.upsamplers: | |
hidden_states = upsampler(hidden_states) | |
return hidden_states | |
class UNetMidBlockSpatioTemporal(nn.Module): | |
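# Intention of this block: mid block of the spatio-temporal UNet (as used by video models
# such as Stable Video Diffusion). TransformerSpatioTemporalModel layers apply spatial and
# temporal attention with cross-attention to the conditioning, alternating with
# SpatioTemporalResBlocks at the lowest resolution.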
def __init__( | |
self, | |
in_channels: int, | |
temb_channels: int, | |
num_layers: int = 1, | |
transformer_layers_per_block: Union[int, Tuple[int]] = 1, | |
num_attention_heads: int = 1, | |
cross_attention_dim: int = 1280, | |
): | |
super().__init__() | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
# support for variable transformer layers per block | |
if isinstance(transformer_layers_per_block, int): | |
transformer_layers_per_block = [transformer_layers_per_block] * num_layers | |
# there is always at least one resnet | |
resnets = [ | |
SpatioTemporalResBlock( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
temb_channels=temb_channels, | |
eps=1e-5, | |
) | |
] | |
attentions = [] | |
for i in range(num_layers): | |
attentions.append( | |
TransformerSpatioTemporalModel( | |
num_attention_heads, | |
in_channels // num_attention_heads, | |
in_channels=in_channels, | |
num_layers=transformer_layers_per_block[i], | |
cross_attention_dim=cross_attention_dim, | |
) | |
) | |
resnets.append( | |
SpatioTemporalResBlock( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
temb_channels=temb_channels, | |
eps=1e-5, | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
) -> torch.FloatTensor: | |
hidden_states = self.resnets[0]( | |
hidden_states, | |
temb, | |
image_only_indicator=image_only_indicator, | |
) | |
for attn, resnet in zip(self.attentions, self.resnets[1:]): | |
if self.training and self.gradient_checkpointing: # TODO | |
def create_custom_forward(module, return_dict=None): | |
def custom_forward(*inputs): | |
if return_dict is not None: | |
return module(*inputs, return_dict=return_dict) | |
else: | |
return module(*inputs) | |
return custom_forward | |
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
return_dict=False, | |
)[0] | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
image_only_indicator, | |
**ckpt_kwargs, | |
) | |
else: | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
return_dict=False, | |
)[0] | |
hidden_states = resnet( | |
hidden_states, | |
temb, | |
image_only_indicator=image_only_indicator, | |
) | |
return hidden_states | |
class DownBlockSpatioTemporal(nn.Module): | |
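# Intention of this block: encoder ("down") stage of the spatio-temporal UNet without
# cross-attention. SpatioTemporalResBlocks process each frame and mix across frames, a
# strided Downsample2D halves the spatial resolution, and every intermediate state is
# collected as a skip connection for the decoder.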
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
num_layers: int = 1, | |
add_downsample: bool = True, | |
): | |
super().__init__() | |
resnets = [] | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
SpatioTemporalResBlock( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=1e-5, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
if add_downsample: | |
self.downsamplers = nn.ModuleList( | |
[ | |
Downsample2D( | |
out_channels, | |
use_conv=True, | |
out_channels=out_channels, | |
name="op", | |
) | |
] | |
) | |
else: | |
self.downsamplers = None | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: | |
output_states = () | |
for resnet in self.resnets: | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module): | |
def custom_forward(*inputs): | |
return module(*inputs) | |
return custom_forward | |
if is_torch_version(">=", "1.11.0"): | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
image_only_indicator, | |
use_reentrant=False, | |
) | |
else: | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
image_only_indicator, | |
) | |
else: | |
hidden_states = resnet( | |
hidden_states, | |
temb, | |
image_only_indicator=image_only_indicator, | |
) | |
output_states = output_states + (hidden_states,) | |
if self.downsamplers is not None: | |
for downsampler in self.downsamplers: | |
hidden_states = downsampler(hidden_states) | |
output_states = output_states + (hidden_states,) | |
return hidden_states, output_states | |
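# --- Hedged usage sketch (illustrative, not part of the original source) ---
# A minimal smoke test for DownBlockSpatioTemporal, assuming SpatioTemporalResBlock and
# Downsample2D behave as in recent diffusers releases. The frame axis is folded into the
# batch dimension and `image_only_indicator` has shape (batch, num_frames).
def _example_down_block_spatio_temporal():
    import torch

    batch, frames = 1, 4
    block = DownBlockSpatioTemporal(
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        num_layers=2,
    )
    sample = torch.randn(batch * frames, 320, 32, 32)
    temb = torch.randn(batch * frames, 1280)
    image_only_indicator = torch.zeros(batch, frames)  # 0 = treat every entry as video
    out, skips = block(sample, temb=temb, image_only_indicator=image_only_indicator)
    assert out.shape == (batch * frames, 320, 16, 16)  # Downsample2D halves H and W
    assert len(skips) == 3  # one state per resnet plus one after downsampling
    return out, skips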
class CrossAttnDownBlockSpatioTemporal(nn.Module): | |
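# Intention of this block: encoder ("down") stage of the spatio-temporal UNet with
# cross-attention. Each SpatioTemporalResBlock is followed by a
# TransformerSpatioTemporalModel that attends spatially, temporally, and to the
# conditioning embeddings; a strided Downsample2D ends the stage.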
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
num_layers: int = 1, | |
transformer_layers_per_block: Union[int, Tuple[int]] = 1, | |
num_attention_heads: int = 1, | |
cross_attention_dim: int = 1280, | |
add_downsample: bool = True, | |
): | |
super().__init__() | |
resnets = [] | |
attentions = [] | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
if isinstance(transformer_layers_per_block, int): | |
transformer_layers_per_block = [transformer_layers_per_block] * num_layers | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
resnets.append( | |
SpatioTemporalResBlock( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=1e-6, | |
) | |
) | |
attentions.append( | |
TransformerSpatioTemporalModel( | |
num_attention_heads, | |
out_channels // num_attention_heads, | |
in_channels=out_channels, | |
num_layers=transformer_layers_per_block[i], | |
cross_attention_dim=cross_attention_dim, | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
if add_downsample: | |
self.downsamplers = nn.ModuleList( | |
[ | |
Downsample2D( | |
out_channels, | |
use_conv=True, | |
out_channels=out_channels, | |
padding=1, | |
name="op", | |
) | |
] | |
) | |
else: | |
self.downsamplers = None | |
self.gradient_checkpointing = False | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: | |
output_states = () | |
blocks = list(zip(self.resnets, self.attentions)) | |
for resnet, attn in blocks: | |
if self.training and self.gradient_checkpointing: # TODO | |
def create_custom_forward(module, return_dict=None): | |
def custom_forward(*inputs): | |
if return_dict is not None: | |
return module(*inputs, return_dict=return_dict) | |
else: | |
return module(*inputs) | |
return custom_forward | |
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
image_only_indicator, | |
**ckpt_kwargs, | |
) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
return_dict=False, | |
)[0] | |
else: | |
hidden_states = resnet( | |
hidden_states, | |
temb, | |
image_only_indicator=image_only_indicator, | |
) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
return_dict=False, | |
)[0] | |
output_states = output_states + (hidden_states,) | |
if self.downsamplers is not None: | |
for downsampler in self.downsamplers: | |
hidden_states = downsampler(hidden_states) | |
output_states = output_states + (hidden_states,) | |
return hidden_states, output_states | |
class UpBlockSpatioTemporal(nn.Module): | |
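# Intention of this block: decoder ("up") stage of the spatio-temporal UNet without
# cross-attention. Encoder skip connections are concatenated channel-wise before each
# SpatioTemporalResBlock, and an optional Upsample2D doubles the spatial resolution.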
def __init__( | |
self, | |
in_channels: int, | |
prev_output_channel: int, | |
out_channels: int, | |
temb_channels: int, | |
resolution_idx: Optional[int] = None, | |
num_layers: int = 1, | |
resnet_eps: float = 1e-6, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
resnets = [] | |
for i in range(num_layers): | |
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels | |
resnet_in_channels = prev_output_channel if i == 0 else out_channels | |
resnets.append( | |
SpatioTemporalResBlock( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
) | |
) | |
self.resnets = nn.ModuleList(resnets) | |
if add_upsample: | |
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) | |
else: | |
self.upsamplers = None | |
self.gradient_checkpointing = False | |
self.resolution_idx = resolution_idx | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
) -> torch.FloatTensor: | |
for resnet in self.resnets: | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
if self.training and self.gradient_checkpointing: | |
def create_custom_forward(module): | |
def custom_forward(*inputs): | |
return module(*inputs) | |
return custom_forward | |
if is_torch_version(">=", "1.11.0"): | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
image_only_indicator, | |
use_reentrant=False, | |
) | |
else: | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
image_only_indicator, | |
) | |
else: | |
hidden_states = resnet( | |
hidden_states, | |
temb, | |
image_only_indicator=image_only_indicator, | |
) | |
if self.upsamplers is not None: | |
for upsampler in self.upsamplers: | |
hidden_states = upsampler(hidden_states) | |
return hidden_states | |
class CrossAttnUpBlockSpatioTemporal(nn.Module): | |
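# Intention of this block: decoder ("up") stage of the spatio-temporal UNet with
# cross-attention. Skip connections are concatenated before each SpatioTemporalResBlock,
# each followed by a TransformerSpatioTemporalModel; an optional Upsample2D finishes the stage.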
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
prev_output_channel: int, | |
temb_channels: int, | |
resolution_idx: Optional[int] = None, | |
num_layers: int = 1, | |
transformer_layers_per_block: Union[int, Tuple[int]] = 1, | |
resnet_eps: float = 1e-6, | |
num_attention_heads: int = 1, | |
cross_attention_dim: int = 1280, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
resnets = [] | |
attentions = [] | |
self.has_cross_attention = True | |
self.num_attention_heads = num_attention_heads | |
if isinstance(transformer_layers_per_block, int): | |
transformer_layers_per_block = [transformer_layers_per_block] * num_layers | |
for i in range(num_layers): | |
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels | |
resnet_in_channels = prev_output_channel if i == 0 else out_channels | |
resnets.append( | |
SpatioTemporalResBlock( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
eps=resnet_eps, | |
) | |
) | |
attentions.append( | |
TransformerSpatioTemporalModel( | |
num_attention_heads, | |
out_channels // num_attention_heads, | |
in_channels=out_channels, | |
num_layers=transformer_layers_per_block[i], | |
cross_attention_dim=cross_attention_dim, | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
if add_upsample: | |
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) | |
else: | |
self.upsamplers = None | |
self.gradient_checkpointing = False | |
self.resolution_idx = resolution_idx | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
image_only_indicator: Optional[torch.Tensor] = None, | |
) -> torch.FloatTensor: | |
for resnet, attn in zip(self.resnets, self.attentions): | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
if self.training and self.gradient_checkpointing: # TODO | |
def create_custom_forward(module, return_dict=None): | |
def custom_forward(*inputs): | |
if return_dict is not None: | |
return module(*inputs, return_dict=return_dict) | |
else: | |
return module(*inputs) | |
return custom_forward | |
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
hidden_states = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(resnet), | |
hidden_states, | |
temb, | |
image_only_indicator, | |
**ckpt_kwargs, | |
) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
return_dict=False, | |
)[0] | |
else: | |
hidden_states = resnet( | |
hidden_states, | |
temb, | |
image_only_indicator=image_only_indicator, | |
) | |
hidden_states = attn( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
return_dict=False, | |
)[0] | |
if self.upsamplers is not None: | |
for upsampler in self.upsamplers: | |
hidden_states = upsampler(hidden_states) | |
return hidden_states | |
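# ---------------------------------------------------------------------------
# The remainder of this excerpt appears to correspond to diffusers' unet_2d_condition.py:
# the conditional 2D UNet used by Stable Diffusion-style image pipelines.
# ---------------------------------------------------------------------------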
from dataclasses import dataclass | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.utils.checkpoint | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin | |
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers | |
from ..activations import get_activation | |
from ..attention_processor import ( | |
ADDED_KV_ATTENTION_PROCESSORS, | |
CROSS_ATTENTION_PROCESSORS, | |
Attention, | |
AttentionProcessor, | |
AttnAddedKVProcessor, | |
AttnProcessor, | |
) | |
from ..embeddings import ( | |
GaussianFourierProjection, | |
GLIGENTextBoundingboxProjection, | |
ImageHintTimeEmbedding, | |
ImageProjection, | |
ImageTimeEmbedding, | |
TextImageProjection, | |
TextImageTimeEmbedding, | |
TextTimeEmbedding, | |
TimestepEmbedding, | |
Timesteps, | |
) | |
from ..modeling_utils import ModelMixin | |
from .unet_2d_blocks import ( | |
get_down_block, | |
get_mid_block, | |
get_up_block, | |
) | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
@dataclass | |
class UNet2DConditionOutput(BaseOutput): | |
""" | |
The output of [`UNet2DConditionModel`]. | |
Args: | |
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): | |
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. | |
""" | |
sample: torch.FloatTensor = None | |
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): | |
r""" | |
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample | |
shaped output. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving). | |
Parameters: | |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): | |
Height and width of input/output sample. | |
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. | |
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. | |
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. | |
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): | |
Whether to flip the sin to cos in the time embedding. | |
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): | |
The tuple of downsample blocks to use. | |
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): | |
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or | |
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): | |
The tuple of upsample blocks to use. | |
only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
Whether to include self-attention in the basic transformer blocks, see | |
[`~models.attention.BasicTransformerBlock`]. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
The tuple of output channels for each block. | |
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. | |
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. | |
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. | |
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. | |
If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. | |
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): | |
The dimension of the cross attention features. | |
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): | |
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for | |
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], | |
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling | |
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for | |
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], | |
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
encoder_hid_dim (`int`, *optional*, defaults to None): | |
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` | |
dimension to `cross_attention_dim`. | |
encoder_hid_dim_type (`str`, *optional*, defaults to `None`): | |
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text | |
embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. | |
num_attention_heads (`int`, *optional*): | |
The number of attention heads. If not defined, defaults to `attention_head_dim` | |
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config | |
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. | |
class_embed_type (`str`, *optional*, defaults to `None`): | |
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, | |
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. | |
addition_embed_type (`str`, *optional*, defaults to `None`): | |
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or | |
"text". "text" will use the `TextTimeEmbedding` layer. | |
addition_time_embed_dim: (`int`, *optional*, defaults to `None`): | |
Dimension for the timestep embeddings. | |
num_class_embeds (`int`, *optional*, defaults to `None`): | |
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing | |
class conditioning with `class_embed_type` equal to `None`. | |
time_embedding_type (`str`, *optional*, defaults to `positional`): | |
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. | |
time_embedding_dim (`int`, *optional*, defaults to `None`): | |
An optional override for the dimension of the projected time embedding. | |
time_embedding_act_fn (`str`, *optional*, defaults to `None`): | |
Optional activation function to use only once on the time embeddings before they are passed to the rest of | |
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. | |
timestep_post_act (`str`, *optional*, defaults to `None`): | |
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. | |
time_cond_proj_dim (`int`, *optional*, defaults to `None`): | |
The dimension of `cond_proj` layer in the timestep embedding. | |
conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time | |
embeddings with the class embeddings. | |
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): | |
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If | |
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the | |
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` | |
otherwise. | |
""" | |
_supports_gradient_checkpointing = True | |
@register_to_config | |
def __init__( | |
self, | |
sample_size: Optional[int] = None, | |
in_channels: int = 4, | |
out_channels: int = 4, | |
center_input_sample: bool = False, | |
flip_sin_to_cos: bool = True, | |
freq_shift: int = 0, | |
down_block_types: Tuple[str] = ( | |
"CrossAttnDownBlock2D", | |
"CrossAttnDownBlock2D", | |
"CrossAttnDownBlock2D", | |
"DownBlock2D", | |
), | |
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", | |
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), | |
only_cross_attention: Union[bool, Tuple[bool]] = False, | |
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), | |
layers_per_block: Union[int, Tuple[int]] = 2, | |
downsample_padding: int = 1, | |
mid_block_scale_factor: float = 1, | |
dropout: float = 0.0, | |
act_fn: str = "silu", | |
norm_num_groups: Optional[int] = 32, | |
norm_eps: float = 1e-5, | |
cross_attention_dim: Union[int, Tuple[int]] = 1280, | |
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, | |
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, | |
encoder_hid_dim: Optional[int] = None, | |
encoder_hid_dim_type: Optional[str] = None, | |
attention_head_dim: Union[int, Tuple[int]] = 8, | |
num_attention_heads: Optional[Union[int, Tuple[int]]] = None, | |
dual_cross_attention: bool = False, | |
use_linear_projection: bool = False, | |
class_embed_type: Optional[str] = None, | |
addition_embed_type: Optional[str] = None, | |
addition_time_embed_dim: Optional[int] = None, | |
num_class_embeds: Optional[int] = None, | |
upcast_attention: bool = False, | |
resnet_time_scale_shift: str = "default", | |
resnet_skip_time_act: bool = False, | |
resnet_out_scale_factor: float = 1.0, | |
time_embedding_type: str = "positional", | |
time_embedding_dim: Optional[int] = None, | |
time_embedding_act_fn: Optional[str] = None, | |
timestep_post_act: Optional[str] = None, | |
time_cond_proj_dim: Optional[int] = None, | |
conv_in_kernel: int = 3, | |
conv_out_kernel: int = 3, | |
projection_class_embeddings_input_dim: Optional[int] = None, | |
attention_type: str = "default", | |
class_embeddings_concat: bool = False, | |
mid_block_only_cross_attention: Optional[bool] = None, | |
cross_attention_norm: Optional[str] = None, | |
addition_embed_type_num_heads: int = 64, | |
): | |
super().__init__() | |
self.sample_size = sample_size | |
if num_attention_heads is not None: | |
raise ValueError( | |
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." | |
) | |
# If `num_attention_heads` is not defined (which is the case for most models) | |
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is. | |
# The reason for this behavior is to correct for incorrectly named variables that were introduced | |
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 | |
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking | |
# which is why we correct for the naming here. | |
num_attention_heads = num_attention_heads or attention_head_dim | |
# Check inputs | |
self._check_config( | |
down_block_types=down_block_types, | |
up_block_types=up_block_types, | |
only_cross_attention=only_cross_attention, | |
block_out_channels=block_out_channels, | |
layers_per_block=layers_per_block, | |
cross_attention_dim=cross_attention_dim, | |
transformer_layers_per_block=transformer_layers_per_block, | |
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, | |
attention_head_dim=attention_head_dim, | |
num_attention_heads=num_attention_heads, | |
) | |
# input | |
conv_in_padding = (conv_in_kernel - 1) // 2 | |
self.conv_in = nn.Conv2d( | |
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding | |
) | |
# time | |
time_embed_dim, timestep_input_dim = self._set_time_proj( | |
time_embedding_type, | |
block_out_channels=block_out_channels, | |
flip_sin_to_cos=flip_sin_to_cos, | |
freq_shift=freq_shift, | |
time_embedding_dim=time_embedding_dim, | |
) | |
self.time_embedding = TimestepEmbedding( | |
timestep_input_dim, | |
time_embed_dim, | |
act_fn=act_fn, | |
post_act_fn=timestep_post_act, | |
cond_proj_dim=time_cond_proj_dim, | |
) | |
self._set_encoder_hid_proj( | |
encoder_hid_dim_type, | |
cross_attention_dim=cross_attention_dim, | |
encoder_hid_dim=encoder_hid_dim, | |
) | |
# class embedding | |
self._set_class_embedding( | |
class_embed_type, | |
act_fn=act_fn, | |
num_class_embeds=num_class_embeds, | |
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, | |
time_embed_dim=time_embed_dim, | |
timestep_input_dim=timestep_input_dim, | |
) | |
self._set_add_embedding( | |
addition_embed_type, | |
addition_embed_type_num_heads=addition_embed_type_num_heads, | |
addition_time_embed_dim=addition_time_embed_dim, | |
cross_attention_dim=cross_attention_dim, | |
encoder_hid_dim=encoder_hid_dim, | |
flip_sin_to_cos=flip_sin_to_cos, | |
freq_shift=freq_shift, | |
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, | |
time_embed_dim=time_embed_dim, | |
) | |
if time_embedding_act_fn is None: | |
self.time_embed_act = None | |
else: | |
self.time_embed_act = get_activation(time_embedding_act_fn) | |
self.down_blocks = nn.ModuleList([]) | |
self.up_blocks = nn.ModuleList([]) | |
if isinstance(only_cross_attention, bool): | |
if mid_block_only_cross_attention is None: | |
mid_block_only_cross_attention = only_cross_attention | |
only_cross_attention = [only_cross_attention] * len(down_block_types) | |
if mid_block_only_cross_attention is None: | |
mid_block_only_cross_attention = False | |
if isinstance(num_attention_heads, int): | |
num_attention_heads = (num_attention_heads,) * len(down_block_types) | |
if isinstance(attention_head_dim, int): | |
attention_head_dim = (attention_head_dim,) * len(down_block_types) | |
if isinstance(cross_attention_dim, int): | |
cross_attention_dim = (cross_attention_dim,) * len(down_block_types) | |
if isinstance(layers_per_block, int): | |
layers_per_block = [layers_per_block] * len(down_block_types) | |
if isinstance(transformer_layers_per_block, int): | |
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) | |
if class_embeddings_concat: | |
# The time embeddings are concatenated with the class embeddings. The dimension of the | |
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the | |
# regular time embeddings | |
blocks_time_embed_dim = time_embed_dim * 2 | |
else: | |
blocks_time_embed_dim = time_embed_dim | |
# down | |
output_channel = block_out_channels[0] | |
for i, down_block_type in enumerate(down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = get_down_block( | |
down_block_type, | |
num_layers=layers_per_block[i], | |
transformer_layers_per_block=transformer_layers_per_block[i], | |
in_channels=input_channel, | |
out_channels=output_channel, | |
temb_channels=blocks_time_embed_dim, | |
add_downsample=not is_final_block, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
cross_attention_dim=cross_attention_dim[i], | |
num_attention_heads=num_attention_heads[i], | |
downsample_padding=downsample_padding, | |
dual_cross_attention=dual_cross_attention, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention[i], | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
attention_type=attention_type, | |
resnet_skip_time_act=resnet_skip_time_act, | |
resnet_out_scale_factor=resnet_out_scale_factor, | |
cross_attention_norm=cross_attention_norm, | |
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, | |
dropout=dropout, | |
) | |
self.down_blocks.append(down_block) | |
# mid | |
self.mid_block = get_mid_block( | |
mid_block_type, | |
temb_channels=blocks_time_embed_dim, | |
in_channels=block_out_channels[-1], | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
output_scale_factor=mid_block_scale_factor, | |
transformer_layers_per_block=transformer_layers_per_block[-1], | |
num_attention_heads=num_attention_heads[-1], | |
cross_attention_dim=cross_attention_dim[-1], | |
dual_cross_attention=dual_cross_attention, | |
use_linear_projection=use_linear_projection, | |
mid_block_only_cross_attention=mid_block_only_cross_attention, | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
attention_type=attention_type, | |
resnet_skip_time_act=resnet_skip_time_act, | |
cross_attention_norm=cross_attention_norm, | |
attention_head_dim=attention_head_dim[-1], | |
dropout=dropout, | |
) | |
# count how many layers upsample the images | |
self.num_upsamplers = 0 | |
# up | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
reversed_num_attention_heads = list(reversed(num_attention_heads)) | |
reversed_layers_per_block = list(reversed(layers_per_block)) | |
reversed_cross_attention_dim = list(reversed(cross_attention_dim)) | |
reversed_transformer_layers_per_block = ( | |
list(reversed(transformer_layers_per_block)) | |
if reverse_transformer_layers_per_block is None | |
else reverse_transformer_layers_per_block | |
) | |
only_cross_attention = list(reversed(only_cross_attention)) | |
output_channel = reversed_block_out_channels[0] | |
for i, up_block_type in enumerate(up_block_types): | |
is_final_block = i == len(block_out_channels) - 1 | |
prev_output_channel = output_channel | |
output_channel = reversed_block_out_channels[i] | |
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] | |
# add upsample block for all BUT final layer | |
if not is_final_block: | |
add_upsample = True | |
self.num_upsamplers += 1 | |
else: | |
add_upsample = False | |
up_block = get_up_block( | |
up_block_type, | |
num_layers=reversed_layers_per_block[i] + 1, | |
transformer_layers_per_block=reversed_transformer_layers_per_block[i], | |
in_channels=input_channel, | |
out_channels=output_channel, | |
prev_output_channel=prev_output_channel, | |
temb_channels=blocks_time_embed_dim, | |
add_upsample=add_upsample, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resolution_idx=i, | |
resnet_groups=norm_num_groups, | |
cross_attention_dim=reversed_cross_attention_dim[i], | |
num_attention_heads=reversed_num_attention_heads[i], | |
dual_cross_attention=dual_cross_attention, | |
use_linear_projection=use_linear_projection, | |
only_cross_attention=only_cross_attention[i], | |
upcast_attention=upcast_attention, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
attention_type=attention_type, | |
resnet_skip_time_act=resnet_skip_time_act, | |
resnet_out_scale_factor=resnet_out_scale_factor, | |
cross_attention_norm=cross_attention_norm, | |
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, | |
dropout=dropout, | |
) | |
self.up_blocks.append(up_block) | |
prev_output_channel = output_channel | |
# out | |
if norm_num_groups is not None: | |
self.conv_norm_out = nn.GroupNorm( | |
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps | |
) | |
self.conv_act = get_activation(act_fn) | |
else: | |
self.conv_norm_out = None | |
self.conv_act = None | |
conv_out_padding = (conv_out_kernel - 1) // 2 | |
self.conv_out = nn.Conv2d( | |
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding | |
) | |
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) | |
def _check_config( | |
self, | |
down_block_types: Tuple[str], | |
up_block_types: Tuple[str], | |
only_cross_attention: Union[bool, Tuple[bool]], | |
block_out_channels: Tuple[int], | |
layers_per_block: Union[int, Tuple[int]], | |
cross_attention_dim: Union[int, Tuple[int]], | |
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], | |
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]],
attention_head_dim: int, | |
num_attention_heads: Optional[Union[int, Tuple[int]]], | |
): | |
if len(down_block_types) != len(up_block_types): | |
raise ValueError( | |
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." | |
) | |
if len(block_out_channels) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." | |
) | |
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." | |
) | |
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: | |
for layer_number_per_block in transformer_layers_per_block: | |
if isinstance(layer_number_per_block, list): | |
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") | |
def _set_time_proj( | |
self, | |
time_embedding_type: str, | |
block_out_channels: int, | |
flip_sin_to_cos: bool, | |
freq_shift: float, | |
time_embedding_dim: int, | |
) -> Tuple[int, int]: | |
if time_embedding_type == "fourier": | |
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 | |
if time_embed_dim % 2 != 0: | |
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") | |
self.time_proj = GaussianFourierProjection( | |
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
) | |
timestep_input_dim = time_embed_dim | |
elif time_embedding_type == "positional": | |
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 | |
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
timestep_input_dim = block_out_channels[0] | |
else: | |
raise ValueError( | |
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." | |
) | |
return time_embed_dim, timestep_input_dim | |
def _set_encoder_hid_proj( | |
self, | |
encoder_hid_dim_type: Optional[str], | |
cross_attention_dim: Union[int, Tuple[int]], | |
encoder_hid_dim: Optional[int], | |
): | |
if encoder_hid_dim_type is None and encoder_hid_dim is not None: | |
encoder_hid_dim_type = "text_proj" | |
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) | |
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") | |
if encoder_hid_dim is None and encoder_hid_dim_type is not None: | |
raise ValueError( | |
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." | |
) | |
if encoder_hid_dim_type == "text_proj": | |
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) | |
elif encoder_hid_dim_type == "text_image_proj": | |
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much | |
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use | |
# case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
self.encoder_hid_proj = TextImageProjection( | |
text_embed_dim=encoder_hid_dim, | |
image_embed_dim=cross_attention_dim, | |
cross_attention_dim=cross_attention_dim, | |
) | |
elif encoder_hid_dim_type == "image_proj": | |
# Kandinsky 2.2 | |
self.encoder_hid_proj = ImageProjection( | |
image_embed_dim=encoder_hid_dim, | |
cross_attention_dim=cross_attention_dim, | |
) | |
elif encoder_hid_dim_type is not None: | |
raise ValueError( | |
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." | |
) | |
else: | |
self.encoder_hid_proj = None | |
def _set_class_embedding( | |
self, | |
class_embed_type: Optional[str], | |
act_fn: str, | |
num_class_embeds: Optional[int], | |
projection_class_embeddings_input_dim: Optional[int], | |
time_embed_dim: int, | |
timestep_input_dim: int, | |
): | |
if class_embed_type is None and num_class_embeds is not None: | |
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) | |
elif class_embed_type == "timestep": | |
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) | |
elif class_embed_type == "identity": | |
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) | |
elif class_embed_type == "projection": | |
if projection_class_embeddings_input_dim is None: | |
raise ValueError( | |
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" | |
) | |
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except | |
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings | |
# 2. it projects from an arbitrary input dimension. | |
# | |
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. | |
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. | |
# As a result, `TimestepEmbedding` can be passed arbitrary vectors. | |
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) | |
elif class_embed_type == "simple_projection": | |
if projection_class_embeddings_input_dim is None: | |
raise ValueError( | |
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" | |
) | |
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) | |
else: | |
self.class_embedding = None | |
def _set_add_embedding( | |
self, | |
addition_embed_type: str, | |
addition_embed_type_num_heads: int, | |
addition_time_embed_dim: Optional[int], | |
flip_sin_to_cos: bool, | |
freq_shift: float, | |
cross_attention_dim: Optional[int], | |
encoder_hid_dim: Optional[int], | |
projection_class_embeddings_input_dim: Optional[int], | |
time_embed_dim: int, | |
): | |
if addition_embed_type == "text": | |
if encoder_hid_dim is not None: | |
text_time_embedding_from_dim = encoder_hid_dim | |
else: | |
text_time_embedding_from_dim = cross_attention_dim | |
self.add_embedding = TextTimeEmbedding( | |
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads | |
) | |
elif addition_embed_type == "text_image": | |
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much | |
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use | |
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
self.add_embedding = TextImageTimeEmbedding( | |
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim | |
) | |
elif addition_embed_type == "text_time": | |
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) | |
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) | |
elif addition_embed_type == "image": | |
# Kandinsky 2.2 | |
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) | |
elif addition_embed_type == "image_hint": | |
# Kandinsky 2.2 ControlNet | |
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) | |
elif addition_embed_type is not None: | |
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") | |
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): | |
if attention_type in ["gated", "gated-text-image"]: | |
positive_len = 768 | |
if isinstance(cross_attention_dim, int): | |
positive_len = cross_attention_dim | |
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): | |
positive_len = cross_attention_dim[0] | |
feature_type = "text-only" if attention_type == "gated" else "text-image" | |
self.position_net = GLIGENTextBoundingboxProjection( | |
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type | |
) | |
@property | |
def attn_processors(self) -> Dict[str, AttentionProcessor]: | |
r""" | |
Returns: | |
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
""" | |
# set recursively | |
processors = {} | |
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): | |
if hasattr(module, "get_processor"): | |
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) | |
for sub_name, child in module.named_children(): | |
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) | |
return processors | |
for name, module in self.named_children(): | |
fn_recursive_add_processors(name, module, processors) | |
return processors | |
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): | |
r""" | |
Sets the attention processor to use to compute attention. | |
Parameters: | |
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): | |
The instantiated processor class or a dictionary of processor classes that will be set as the processor | |
for **all** `Attention` layers. | |
If `processor` is a dict, the key needs to define the path to the corresponding cross attention | |
processor. This is strongly recommended when setting trainable attention processors. | |
""" | |
count = len(self.attn_processors.keys()) | |
if isinstance(processor, dict) and len(processor) != count: | |
raise ValueError( | |
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" | |
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." | |
) | |
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
if hasattr(module, "set_processor"): | |
if not isinstance(processor, dict): | |
module.set_processor(processor) | |
else: | |
module.set_processor(processor.pop(f"{name}.processor")) | |
for sub_name, child in module.named_children(): | |
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) | |
for name, module in self.named_children(): | |
fn_recursive_attn_processor(name, module, processor) | |
def set_default_attn_processor(self): | |
""" | |
Disables custom attention processors and sets the default attention implementation. | |
""" | |
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnAddedKVProcessor() | |
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnProcessor() | |
else: | |
raise ValueError( | |
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" | |
) | |
self.set_attn_processor(processor) | |
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): | |
r""" | |
Enable sliced attention computation. | |
When this option is enabled, the attention module splits the input tensor in slices to compute attention in | |
several steps. This is useful for saving some memory in exchange for a small decrease in speed. | |
Args: | |
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): | |
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If | |
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is | |
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` | |
must be a multiple of `slice_size`. | |
""" | |
sliceable_head_dims = [] | |
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): | |
if hasattr(module, "set_attention_slice"): | |
sliceable_head_dims.append(module.sliceable_head_dim) | |
for child in module.children(): | |
fn_recursive_retrieve_sliceable_dims(child) | |
# retrieve number of attention layers | |
for module in self.children(): | |
fn_recursive_retrieve_sliceable_dims(module) | |
num_sliceable_layers = len(sliceable_head_dims) | |
if slice_size == "auto": | |
# half the attention head size is usually a good trade-off between | |
# speed and memory | |
slice_size = [dim // 2 for dim in sliceable_head_dims] | |
elif slice_size == "max": | |
# make smallest slice possible | |
slice_size = num_sliceable_layers * [1] | |
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size | |
if len(slice_size) != len(sliceable_head_dims): | |
raise ValueError( | |
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" | |
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." | |
) | |
for i in range(len(slice_size)): | |
size = slice_size[i] | |
dim = sliceable_head_dims[i] | |
if size is not None and size > dim: | |
raise ValueError(f"size {size} has to be smaller or equal to {dim}.") | |
# Recursively walk through all the children. | |
# Any child that exposes the set_attention_slice method | |
# gets the message | |
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): | |
if hasattr(module, "set_attention_slice"): | |
module.set_attention_slice(slice_size.pop()) | |
for child in module.children(): | |
fn_recursive_set_attention_slice(child, slice_size) | |
reversed_slice_size = list(reversed(slice_size)) | |
for module in self.children(): | |
fn_recursive_set_attention_slice(module, reversed_slice_size) | |
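# Hedged usage sketch (comment only): sliced attention trades a little speed for lower peak memory. | |
# `unet` is assumed to be an instance of this class. | |
# unet.set_attention_slice("auto")   # halve each sliceable head dim -> attention runs in two steps | |
# unet.set_attention_slice("max")    # one slice at a time -> lowest memory, slowest | |
# unet.set_attention_slice(2)        # explicit slice size; must not exceed any sliceable head dim | |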
def _set_gradient_checkpointing(self, module, value=False): | |
if hasattr(module, "gradient_checkpointing"): | |
module.gradient_checkpointing = value | |
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): | |
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. | |
The suffixes after the scaling factors represent the stage blocks where they are being applied. | |
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that | |
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. | |
Args: | |
s1 (`float`): | |
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to | |
mitigate the "oversmoothing effect" in the enhanced denoising process. | |
s2 (`float`): | |
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to | |
mitigate the "oversmoothing effect" in the enhanced denoising process. | |
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. | |
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. | |
""" | |
for i, upsample_block in enumerate(self.up_blocks): | |
setattr(upsample_block, "s1", s1) | |
setattr(upsample_block, "s2", s2) | |
setattr(upsample_block, "b1", b1) | |
setattr(upsample_block, "b2", b2) | |
def disable_freeu(self): | |
"""Disables the FreeU mechanism.""" | |
freeu_keys = {"s1", "s2", "b1", "b2"} | |
for i, upsample_block in enumerate(self.up_blocks): | |
for k in freeu_keys: | |
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: | |
setattr(upsample_block, k, None) | |
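# Hedged usage sketch (comment only): FreeU re-weights backbone vs. skip features in the up blocks. | |
# The values below are the ones the FreeU repository suggests for SD1.x; treat them as a starting | |
# point, not a guarantee. | |
# unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) | |
# ...run inference... | |
# unet.disable_freeu() | |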
def fuse_qkv_projections(self): | |
""" | |
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, | |
key, value) are fused. For cross-attention modules, key and value projection matrices are fused. | |
<Tip warning={true}> | |
This API is 🧪 experimental. | |
</Tip> | |
""" | |
self.original_attn_processors = None | |
for _, attn_processor in self.attn_processors.items(): | |
if "Added" in str(attn_processor.__class__.__name__): | |
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") | |
self.original_attn_processors = self.attn_processors | |
for module in self.modules(): | |
if isinstance(module, Attention): | |
module.fuse_projections(fuse=True) | |
def unfuse_qkv_projections(self): | |
"""Disables the fused QKV projection if enabled. | |
<Tip warning={true}> | |
This API is 🧪 experimental. | |
</Tip> | |
""" | |
if self.original_attn_processors is not None: | |
self.set_attn_processor(self.original_attn_processors) | |
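# Hedged usage sketch (comment only): fused QKV projections are an experimental micro-optimization. | |
# unet.fuse_qkv_projections()      # fuse q/k/v (self-attn) and k/v (cross-attn) weight matrices | |
# ...benchmark or run inference... | |
# unet.unfuse_qkv_projections()    # restores the processors saved in `original_attn_processors` | |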
def unload_lora(self): | |
"""Unloads LoRA weights.""" | |
deprecate( | |
"unload_lora", | |
"0.28.0", | |
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", | |
) | |
for module in self.modules(): | |
if hasattr(module, "set_lora_layer"): | |
module.set_lora_layer(None) | |
def get_time_embed( | |
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] | |
) -> Optional[torch.Tensor]: | |
timesteps = timestep | |
if not torch.is_tensor(timesteps): | |
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
# This would be a good case for the `match` statement (Python 3.10+) | |
is_mps = sample.device.type == "mps" | |
if isinstance(timestep, float): | |
dtype = torch.float32 if is_mps else torch.float64 | |
else: | |
dtype = torch.int32 if is_mps else torch.int64 | |
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
elif len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(sample.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
timesteps = timesteps.expand(sample.shape[0]) | |
t_emb = self.time_proj(timesteps) | |
# `Timesteps` does not contain any weights and will always return f32 tensors | |
# but time_embedding might actually be running in fp16. so we need to cast here. | |
# there might be better ways to encapsulate this. | |
t_emb = t_emb.to(dtype=sample.dtype) | |
return t_emb | |
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: | |
class_emb = None | |
if self.class_embedding is not None: | |
if class_labels is None: | |
raise ValueError("class_labels should be provided when num_class_embeds > 0") | |
if self.config.class_embed_type == "timestep": | |
class_labels = self.time_proj(class_labels) | |
# `Timesteps` does not contain any weights and will always return f32 tensors | |
# there might be better ways to encapsulate this. | |
class_labels = class_labels.to(dtype=sample.dtype) | |
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) | |
return class_emb | |
def get_aug_embed( | |
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] | |
) -> Optional[torch.Tensor]: | |
aug_emb = None | |
if self.config.addition_embed_type == "text": | |
aug_emb = self.add_embedding(encoder_hidden_states) | |
elif self.config.addition_embed_type == "text_image": | |
# Kandinsky 2.1 - style | |
if "image_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" | |
) | |
image_embs = added_cond_kwargs.get("image_embeds") | |
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) | |
aug_emb = self.add_embedding(text_embs, image_embs) | |
elif self.config.addition_embed_type == "text_time": | |
# SDXL - style | |
if "text_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" | |
) | |
text_embeds = added_cond_kwargs.get("text_embeds") | |
if "time_ids" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" | |
) | |
time_ids = added_cond_kwargs.get("time_ids") | |
time_embeds = self.add_time_proj(time_ids.flatten()) | |
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) | |
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) | |
add_embeds = add_embeds.to(emb.dtype) | |
aug_emb = self.add_embedding(add_embeds) | |
elif self.config.addition_embed_type == "image": | |
# Kandinsky 2.2 - style | |
if "image_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" | |
) | |
image_embs = added_cond_kwargs.get("image_embeds") | |
aug_emb = self.add_embedding(image_embs) | |
elif self.config.addition_embed_type == "image_hint": | |
# Kandinsky 2.2 - style | |
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" | |
) | |
image_embs = added_cond_kwargs.get("image_embeds") | |
hint = added_cond_kwargs.get("hint") | |
aug_emb = self.add_embedding(image_embs, hint) | |
return aug_emb | |
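# Hedged sketch (comment only): the `added_cond_kwargs` expected by the SDXL-style "text_time" | |
# branch above. The shapes are illustrative assumptions (SDXL uses a 1280-dim pooled text embedding | |
# and six time ids: original size, crop coordinates, target size); `batch_size` is hypothetical. | |
# added_cond_kwargs = { | |
#     "text_embeds": torch.randn(batch_size, 1280), | |
#     "time_ids": torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch_size), | |
# } | |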
def process_encoder_hidden_states( | |
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] | |
) -> torch.Tensor: | |
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": | |
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) | |
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": | |
# Kandinsky 2.1 - style | |
if "image_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" | |
) | |
image_embeds = added_cond_kwargs.get("image_embeds") | |
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) | |
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": | |
# Kandinsky 2.2 - style | |
if "image_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" | |
) | |
image_embeds = added_cond_kwargs.get("image_embeds") | |
encoder_hidden_states = self.encoder_hid_proj(image_embeds) | |
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": | |
if "image_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" | |
) | |
image_embeds = added_cond_kwargs.get("image_embeds") | |
image_embeds = self.encoder_hid_proj(image_embeds) | |
encoder_hidden_states = (encoder_hidden_states, image_embeds) | |
return encoder_hidden_states | |
def forward( | |
self, | |
sample: torch.FloatTensor, | |
timestep: Union[torch.Tensor, float, int], | |
encoder_hidden_states: torch.Tensor, | |
class_labels: Optional[torch.Tensor] = None, | |
timestep_cond: Optional[torch.Tensor] = None, | |
attention_mask: Optional[torch.Tensor] = None, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
mid_block_additional_residual: Optional[torch.Tensor] = None, | |
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
encoder_attention_mask: Optional[torch.Tensor] = None, | |
return_dict: bool = True, | |
) -> Union[UNet2DConditionOutput, Tuple]: | |
r""" | |
The [`UNet2DConditionModel`] forward method. | |
Args: | |
sample (`torch.FloatTensor`): | |
The noisy input tensor with the following shape `(batch, channel, height, width)`. | |
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. | |
encoder_hidden_states (`torch.FloatTensor`): | |
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. | |
class_labels (`torch.Tensor`, *optional*, defaults to `None`): | |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. | |
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): | |
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed | |
through the `self.time_embedding` layer to obtain the timestep embeddings. | |
attention_mask (`torch.Tensor`, *optional*, defaults to `None`): | |
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask | |
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large | |
negative values to the attention scores corresponding to "discard" tokens. | |
cross_attention_kwargs (`dict`, *optional*): | |
A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under | |
`self.processor` in | |
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
added_cond_kwargs: (`dict`, *optional*): | |
A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that | |
are passed along to the UNet blocks. | |
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): | |
Additional residuals to be added to the UNet long skip connections from down blocks to up blocks, for | |
example from ControlNet side model(s). | |
mid_block_additional_residual (`torch.Tensor`, *optional*): | |
An additional residual to be added to the UNet mid block output, for example from a ControlNet side model. | |
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): | |
Additional residuals to be added within the UNet down blocks, for example from T2I-Adapter side model(s). | |
encoder_attention_mask (`torch.Tensor`): | |
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If | |
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, | |
which adds large negative values to the attention scores corresponding to "discard" tokens. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain | |
tuple. | |
Returns: | |
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: | |
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise | |
a `tuple` is returned where the first element is the sample tensor. | |
""" | |
# By default samples have to be at least a multiple of the overall upsampling factor. | |
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers). | |
# However, the upsampling interpolation output size can be forced to fit any upsampling size | |
# on the fly if necessary. | |
default_overall_up_factor = 2**self.num_upsamplers | |
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` | |
forward_upsample_size = False | |
upsample_size = None | |
for dim in sample.shape[-2:]: | |
if dim % default_overall_up_factor != 0: | |
# Forward upsample size to force interpolation output size. | |
forward_upsample_size = True | |
break | |
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension | |
# expects mask of shape: | |
# [batch, key_tokens] | |
# adds singleton query_tokens dimension: | |
# [batch, 1, key_tokens] | |
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: | |
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) | |
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) | |
if attention_mask is not None: | |
# assume that mask is expressed as: | |
# (1 = keep, 0 = discard) | |
# convert mask into a bias that can be added to attention scores: | |
# (keep = +0, discard = -10000.0) | |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
attention_mask = attention_mask.unsqueeze(1) | |
# convert encoder_attention_mask to a bias the same way we do for attention_mask | |
if encoder_attention_mask is not None: | |
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 | |
encoder_attention_mask = encoder_attention_mask.unsqueeze(1) | |
# 0. center input if necessary | |
if self.config.center_input_sample: | |
sample = 2 * sample - 1.0 | |
# 1. time | |
t_emb = self.get_time_embed(sample=sample, timestep=timestep) | |
emb = self.time_embedding(t_emb, timestep_cond) | |
aug_emb = None | |
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) | |
if class_emb is not None: | |
if self.config.class_embeddings_concat: | |
emb = torch.cat([emb, class_emb], dim=-1) | |
else: | |
emb = emb + class_emb | |
aug_emb = self.get_aug_embed( | |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
) | |
if self.config.addition_embed_type == "image_hint": | |
aug_emb, hint = aug_emb | |
sample = torch.cat([sample, hint], dim=1) | |
emb = emb + aug_emb if aug_emb is not None else emb | |
if self.time_embed_act is not None: | |
emb = self.time_embed_act(emb) | |
encoder_hidden_states = self.process_encoder_hidden_states( | |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
) | |
# 2. pre-process | |
sample = self.conv_in(sample) | |
# 2.5 GLIGEN position net | |
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: | |
cross_attention_kwargs = cross_attention_kwargs.copy() | |
gligen_args = cross_attention_kwargs.pop("gligen") | |
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} | |
# 3. down | |
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 | |
if USE_PEFT_BACKEND: | |
# weight the lora layers by setting `lora_scale` for each PEFT layer | |
scale_lora_layers(self, lora_scale) | |
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None | |
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets | |
is_adapter = down_intrablock_additional_residuals is not None | |
# maintain backward compatibility for legacy usage, where | |
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg | |
# but can only use one or the other | |
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: | |
deprecate( | |
"T2I should not use down_block_additional_residuals", | |
"1.3.0", | |
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ | |
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ | |
for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ", | |
standard_warn=False, | |
) | |
down_intrablock_additional_residuals = down_block_additional_residuals | |
is_adapter = True | |
down_block_res_samples = (sample,) | |
for downsample_block in self.down_blocks: | |
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
# For t2i-adapter CrossAttnDownBlock2D | |
additional_residuals = {} | |
if is_adapter and len(down_intrablock_additional_residuals) > 0: | |
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) | |
sample, res_samples = downsample_block( | |
hidden_states=sample, | |
temb=emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
cross_attention_kwargs=cross_attention_kwargs, | |
encoder_attention_mask=encoder_attention_mask, | |
**additional_residuals, | |
) | |
else: | |
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) | |
if is_adapter and len(down_intrablock_additional_residuals) > 0: | |
sample += down_intrablock_additional_residuals.pop(0) | |
down_block_res_samples += res_samples | |
if is_controlnet: | |
new_down_block_res_samples = () | |
for down_block_res_sample, down_block_additional_residual in zip( | |
down_block_res_samples, down_block_additional_residuals | |
): | |
down_block_res_sample = down_block_res_sample + down_block_additional_residual | |
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) | |
down_block_res_samples = new_down_block_res_samples | |
# 4. mid | |
if self.mid_block is not None: | |
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: | |
sample = self.mid_block( | |
sample, | |
emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
cross_attention_kwargs=cross_attention_kwargs, | |
encoder_attention_mask=encoder_attention_mask, | |
) | |
else: | |
sample = self.mid_block(sample, emb) | |
# To support T2I-Adapter-XL | |
if ( | |
is_adapter | |
and len(down_intrablock_additional_residuals) > 0 | |
and sample.shape == down_intrablock_additional_residuals[0].shape | |
): | |
sample += down_intrablock_additional_residuals.pop(0) | |
if is_controlnet: | |
sample = sample + mid_block_additional_residual | |
# 5. up | |
for i, upsample_block in enumerate(self.up_blocks): | |
is_final_block = i == len(self.up_blocks) - 1 | |
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
# if we have not reached the final block and need to forward the | |
# upsample size, we do it here | |
if not is_final_block and forward_upsample_size: | |
upsample_size = down_block_res_samples[-1].shape[2:] | |
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
upsample_size=upsample_size, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
) | |
else: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
upsample_size=upsample_size, | |
scale=lora_scale, | |
) | |
# 6. post-process | |
if self.conv_norm_out: | |
sample = self.conv_norm_out(sample) | |
sample = self.conv_act(sample) | |
sample = self.conv_out(sample) | |
if USE_PEFT_BACKEND: | |
# remove `lora_scale` from each PEFT layer | |
unscale_lora_layers(self, lora_scale) | |
if not return_dict: | |
return (sample,) | |
return UNet2DConditionOutput(sample=sample) | |
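# --- Hedged usage sketch (added; not part of the original diffusers source) ---------------------- | |
# Intention: exercise UNet2DConditionModel.forward end to end with a tiny, randomly initialized | |
# config. Every config value and tensor shape below is an illustrative assumption, not a library | |
# default; load a pretrained checkpoint for real use. | |
import torch | |
from diffusers import UNet2DConditionModel | |
_demo_unet = UNet2DConditionModel( | |
sample_size=32,                                              # latent spatial size (assumed) | |
in_channels=4, out_channels=4,                               # SD-style 4-channel latents (assumed) | |
block_out_channels=(64, 128), | |
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), | |
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"), | |
cross_attention_dim=32, | |
layers_per_block=1, | |
) | |
_noisy_latents = torch.randn(2, 4, 32, 32)                   # (batch, channel, height, width) | |
_text_states = torch.randn(2, 77, 32)                        # (batch, seq_len, cross_attention_dim) | |
_pred = _demo_unet(_noisy_latents, timestep=10, encoder_hidden_states=_text_states).sample | |
assert _pred.shape == _noisy_latents.shape                   # the UNet is shape-preserving in latent space | |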
from dataclasses import dataclass | |
from typing import Dict, Tuple, Union | |
import torch | |
import torch.utils.checkpoint | |
from torch import nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...utils import BaseOutput, logging | |
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor | |
from ..embeddings import TimestepEmbedding, Timesteps | |
from ..modeling_utils import ModelMixin | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
@dataclass | |
class Kandinsky3UNetOutput(BaseOutput): | |
sample: torch.FloatTensor = None | |
class Kandinsky3EncoderProj(nn.Module): | |
def __init__(self, encoder_hid_dim, cross_attention_dim): | |
super().__init__() | |
self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False) | |
self.projection_norm = nn.LayerNorm(cross_attention_dim) | |
def forward(self, x): | |
x = self.projection_linear(x) | |
x = self.projection_norm(x) | |
return x | |
class Kandinsky3UNet(ModelMixin, ConfigMixin): | |
@register_to_config | |
def __init__( | |
self, | |
in_channels: int = 4, | |
time_embedding_dim: int = 1536, | |
groups: int = 32, | |
attention_head_dim: int = 64, | |
layers_per_block: Union[int, Tuple[int]] = 3, | |
block_out_channels: Tuple[int] = (384, 768, 1536, 3072), | |
cross_attention_dim: Union[int, Tuple[int]] = 4096, | |
encoder_hid_dim: int = 4096, | |
): | |
super().__init__() | |
# TODO(Yiyi): Give the following 4 parameters better names and move them into the config | |
expansion_ratio = 4 | |
compression_ratio = 2 | |
add_cross_attention = (False, True, True, True) | |
add_self_attention = (False, True, True, True) | |
out_channels = in_channels | |
init_channels = block_out_channels[0] // 2 | |
self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) | |
self.time_embedding = TimestepEmbedding( | |
init_channels, | |
time_embedding_dim, | |
) | |
self.add_time_condition = Kandinsky3AttentionPooling( | |
time_embedding_dim, cross_attention_dim, attention_head_dim | |
) | |
self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1) | |
self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim) | |
hidden_dims = [init_channels] + list(block_out_channels) | |
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) | |
text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention] | |
num_blocks = len(block_out_channels) * [layers_per_block] | |
layer_params = [num_blocks, text_dims, add_self_attention] | |
rev_layer_params = map(reversed, layer_params) | |
cat_dims = [] | |
self.num_levels = len(in_out_dims) | |
self.down_blocks = nn.ModuleList([]) | |
for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate( | |
zip(in_out_dims, *layer_params) | |
): | |
down_sample = level != (self.num_levels - 1) | |
cat_dims.append(out_dim if level != (self.num_levels - 1) else 0) | |
self.down_blocks.append( | |
Kandinsky3DownSampleBlock( | |
in_dim, | |
out_dim, | |
time_embedding_dim, | |
text_dim, | |
res_block_num, | |
groups, | |
attention_head_dim, | |
expansion_ratio, | |
compression_ratio, | |
down_sample, | |
self_attention, | |
) | |
) | |
self.up_blocks = nn.ModuleList([]) | |
for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate( | |
zip(reversed(in_out_dims), *rev_layer_params) | |
): | |
up_sample = level != 0 | |
self.up_blocks.append( | |
Kandinsky3UpSampleBlock( | |
in_dim, | |
cat_dims.pop(), | |
out_dim, | |
time_embedding_dim, | |
text_dim, | |
res_block_num, | |
groups, | |
attention_head_dim, | |
expansion_ratio, | |
compression_ratio, | |
up_sample, | |
self_attention, | |
) | |
) | |
self.conv_norm_out = nn.GroupNorm(groups, init_channels) | |
self.conv_act_out = nn.SiLU() | |
self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1) | |
@property | |
def attn_processors(self) -> Dict[str, AttentionProcessor]: | |
r""" | |
Returns: | |
`dict` of attention processors: A dictionary containing all attention processors used in the model, | |
indexed by its weight name. | |
""" | |
# set recursively | |
processors = {} | |
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): | |
if hasattr(module, "set_processor"): | |
processors[f"{name}.processor"] = module.processor | |
for sub_name, child in module.named_children(): | |
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) | |
return processors | |
for name, module in self.named_children(): | |
fn_recursive_add_processors(name, module, processors) | |
return processors | |
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): | |
r""" | |
Sets the attention processor to use to compute attention. | |
Parameters: | |
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): | |
The instantiated processor class or a dictionary of processor classes that will be set as the processor | |
for **all** `Attention` layers. | |
If `processor` is a dict, the key needs to define the path to the corresponding cross attention | |
processor. This is strongly recommended when setting trainable attention processors. | |
""" | |
count = len(self.attn_processors.keys()) | |
if isinstance(processor, dict) and len(processor) != count: | |
raise ValueError( | |
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" | |
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." | |
) | |
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
if hasattr(module, "set_processor"): | |
if not isinstance(processor, dict): | |
module.set_processor(processor) | |
else: | |
module.set_processor(processor.pop(f"{name}.processor")) | |
for sub_name, child in module.named_children(): | |
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) | |
for name, module in self.named_children(): | |
fn_recursive_attn_processor(name, module, processor) | |
def set_default_attn_processor(self): | |
""" | |
Disables custom attention processors and sets the default attention implementation. | |
""" | |
self.set_attn_processor(AttnProcessor()) | |
def _set_gradient_checkpointing(self, module, value=False): | |
if hasattr(module, "gradient_checkpointing"): | |
module.gradient_checkpointing = value | |
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): | |
if encoder_attention_mask is not None: | |
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 | |
encoder_attention_mask = encoder_attention_mask.unsqueeze(1) | |
if not torch.is_tensor(timestep): | |
dtype = torch.float32 if isinstance(timestep, float) else torch.int32 | |
timestep = torch.tensor([timestep], dtype=dtype, device=sample.device) | |
elif len(timestep.shape) == 0: | |
timestep = timestep[None].to(sample.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
timestep = timestep.expand(sample.shape[0]) | |
time_embed_input = self.time_proj(timestep).to(sample.dtype) | |
time_embed = self.time_embedding(time_embed_input) | |
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) | |
if encoder_hidden_states is not None: | |
time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask) | |
hidden_states = [] | |
sample = self.conv_in(sample) | |
for level, down_sample in enumerate(self.down_blocks): | |
sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) | |
if level != self.num_levels - 1: | |
hidden_states.append(sample) | |
for level, up_sample in enumerate(self.up_blocks): | |
if level != 0: | |
sample = torch.cat([sample, hidden_states.pop()], dim=1) | |
sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) | |
sample = self.conv_norm_out(sample) | |
sample = self.conv_act_out(sample) | |
sample = self.conv_out(sample) | |
if not return_dict: | |
return (sample,) | |
return Kandinsky3UNetOutput(sample=sample) | |
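# Hedged sketch (comment only, not part of the original source): a scaled-down forward pass through | |
# Kandinsky3UNet. All config values and shapes are illustrative assumptions (the released model uses | |
# 4096-dim T5 states and much wider blocks). | |
# unet = Kandinsky3UNet(in_channels=4, time_embedding_dim=128, attention_head_dim=32, | |
#                       block_out_channels=(64, 128, 128, 128), cross_attention_dim=96, encoder_hid_dim=96) | |
# latents = torch.randn(1, 4, 64, 64) | |
# text_states, text_mask = torch.randn(1, 8, 96), torch.ones(1, 8) | |
# out = unet(latents, timestep=10, encoder_hidden_states=text_states, encoder_attention_mask=text_mask).sample | |
# assert out.shape == latents.shape   # the UNet is shape-preserving in latent space | |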
class Kandinsky3UpSampleBlock(nn.Module): | |
def __init__( | |
self, | |
in_channels, | |
cat_dim, | |
out_channels, | |
time_embed_dim, | |
context_dim=None, | |
num_blocks=3, | |
groups=32, | |
head_dim=64, | |
expansion_ratio=4, | |
compression_ratio=2, | |
up_sample=True, | |
self_attention=True, | |
): | |
super().__init__() | |
up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1) | |
hidden_channels = ( | |
[(in_channels + cat_dim, in_channels)] | |
+ [(in_channels, in_channels)] * (num_blocks - 2) | |
+ [(in_channels, out_channels)] | |
) | |
attentions = [] | |
resnets_in = [] | |
resnets_out = [] | |
self.self_attention = self_attention | |
self.context_dim = context_dim | |
if self_attention: | |
attentions.append( | |
Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) | |
) | |
else: | |
attentions.append(nn.Identity()) | |
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): | |
resnets_in.append( | |
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) | |
) | |
if context_dim is not None: | |
attentions.append( | |
Kandinsky3AttentionBlock( | |
in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio | |
) | |
) | |
else: | |
attentions.append(nn.Identity()) | |
resnets_out.append( | |
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets_in = nn.ModuleList(resnets_in) | |
self.resnets_out = nn.ModuleList(resnets_out) | |
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): | |
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): | |
x = resnet_in(x, time_embed) | |
if self.context_dim is not None: | |
x = attention(x, time_embed, context, context_mask, image_mask) | |
x = resnet_out(x, time_embed) | |
if self.self_attention: | |
x = self.attentions[0](x, time_embed, image_mask=image_mask) | |
return x | |
class Kandinsky3DownSampleBlock(nn.Module): | |
def __init__( | |
self, | |
in_channels, | |
out_channels, | |
time_embed_dim, | |
context_dim=None, | |
num_blocks=3, | |
groups=32, | |
head_dim=64, | |
expansion_ratio=4, | |
compression_ratio=2, | |
down_sample=True, | |
self_attention=True, | |
): | |
super().__init__() | |
attentions = [] | |
resnets_in = [] | |
resnets_out = [] | |
self.self_attention = self_attention | |
self.context_dim = context_dim | |
if self_attention: | |
attentions.append( | |
Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) | |
) | |
else: | |
attentions.append(nn.Identity()) | |
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]] | |
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1) | |
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): | |
resnets_in.append( | |
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) | |
) | |
if context_dim is not None: | |
attentions.append( | |
Kandinsky3AttentionBlock( | |
out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio | |
) | |
) | |
else: | |
attentions.append(nn.Identity()) | |
resnets_out.append( | |
Kandinsky3ResNetBlock( | |
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution | |
) | |
) | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets_in = nn.ModuleList(resnets_in) | |
self.resnets_out = nn.ModuleList(resnets_out) | |
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): | |
if self.self_attention: | |
x = self.attentions[0](x, time_embed, image_mask=image_mask) | |
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): | |
x = resnet_in(x, time_embed) | |
if self.context_dim is not None: | |
x = attention(x, time_embed, context, context_mask, image_mask) | |
x = resnet_out(x, time_embed) | |
return x | |
class Kandinsky3ConditionalGroupNorm(nn.Module): | |
def __init__(self, groups, normalized_shape, context_dim): | |
super().__init__() | |
self.norm = nn.GroupNorm(groups, normalized_shape, affine=False) | |
self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape)) | |
self.context_mlp[1].weight.data.zero_() | |
self.context_mlp[1].bias.data.zero_() | |
def forward(self, x, context): | |
context = self.context_mlp(context) | |
for _ in range(len(x.shape[2:])): | |
context = context.unsqueeze(-1) | |
scale, shift = context.chunk(2, dim=1) | |
x = self.norm(x) * (scale + 1.0) + shift | |
return x | |
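# Hedged sketch (comment only): the conditional GroupNorm above is a FiLM-style block: it predicts | |
# a per-channel scale and shift from the time embedding and applies them after an affine-free | |
# GroupNorm. Shapes below are illustrative assumptions. | |
# norm = Kandinsky3ConditionalGroupNorm(groups=8, normalized_shape=32, context_dim=128) | |
# x, temb = torch.randn(2, 32, 16, 16), torch.randn(2, 128) | |
# assert norm(x, temb).shape == x.shape   # modulation preserves the feature-map shape | |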
class Kandinsky3Block(nn.Module): | |
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None): | |
super().__init__() | |
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim) | |
self.activation = nn.SiLU() | |
if up_resolution is not None and up_resolution: | |
self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) | |
else: | |
self.up_sample = nn.Identity() | |
padding = int(kernel_size > 1) | |
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) | |
if up_resolution is not None and not up_resolution: | |
self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) | |
else: | |
self.down_sample = nn.Identity() | |
def forward(self, x, time_embed): | |
x = self.group_norm(x, time_embed) | |
x = self.activation(x) | |
x = self.up_sample(x) | |
x = self.projection(x) | |
x = self.down_sample(x) | |
return x | |
class Kandinsky3ResNetBlock(nn.Module): | |
def __init__( | |
self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None] | |
): | |
super().__init__() | |
kernel_sizes = [1, 3, 3, 1] | |
hidden_channel = max(in_channels, out_channels) // compression_ratio | |
hidden_channels = ( | |
[(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)] | |
) | |
self.resnet_blocks = nn.ModuleList( | |
[ | |
Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution) | |
for (in_channel, out_channel), kernel_size, up_resolution in zip( | |
hidden_channels, kernel_sizes, up_resolutions | |
) | |
] | |
) | |
self.shortcut_up_sample = ( | |
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) | |
if True in up_resolutions | |
else nn.Identity() | |
) | |
self.shortcut_projection = ( | |
nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() | |
) | |
self.shortcut_down_sample = ( | |
nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) | |
if False in up_resolutions | |
else nn.Identity() | |
) | |
def forward(self, x, time_embed): | |
out = x | |
for resnet_block in self.resnet_blocks: | |
out = resnet_block(out, time_embed) | |
x = self.shortcut_up_sample(x) | |
x = self.shortcut_projection(x) | |
x = self.shortcut_down_sample(x) | |
x = x + out | |
return x | |
class Kandinsky3AttentionPooling(nn.Module): | |
def __init__(self, num_channels, context_dim, head_dim=64): | |
super().__init__() | |
self.attention = Attention( | |
context_dim, | |
context_dim, | |
dim_head=head_dim, | |
out_dim=num_channels, | |
out_bias=False, | |
) | |
def forward(self, x, context, context_mask=None): | |
context_mask = context_mask.to(dtype=context.dtype) | |
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask) | |
return x + context.squeeze(1) | |
class Kandinsky3AttentionBlock(nn.Module): | |
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4): | |
super().__init__() | |
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) | |
self.attention = Attention( | |
num_channels, | |
context_dim or num_channels, | |
dim_head=head_dim, | |
out_dim=num_channels, | |
out_bias=False, | |
) | |
hidden_channels = expansion_ratio * num_channels | |
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) | |
self.feed_forward = nn.Sequential( | |
nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False), | |
nn.SiLU(), | |
nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False), | |
) | |
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): | |
height, width = x.shape[-2:] | |
out = self.in_norm(x, time_embed) | |
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1) | |
context = context if context is not None else out | |
if context_mask is not None: | |
context_mask = context_mask.to(dtype=context.dtype) | |
out = self.attention(out, context, context_mask) | |
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) | |
x = x + out | |
out = self.out_norm(x, time_embed) | |
out = self.feed_forward(out) | |
x = x + out | |
return x | |
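# Hedged sketch (comment only, not part of the original source): how the attention block above is | |
# used on a feature map. Channel/shape values are illustrative assumptions. | |
# block = Kandinsky3AttentionBlock(num_channels=64, time_embed_dim=128)   # context_dim=None -> self-attention | |
# x, temb = torch.randn(1, 64, 16, 16), torch.randn(1, 128) | |
# y = block(x, temb)                  # attends over the 16*16 spatial positions, then a 1x1-conv feed-forward | |
# assert y.shape == x.shape | |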
from dataclasses import dataclass | |
from typing import Dict, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...loaders import UNet2DConditionLoadersMixin | |
from ...utils import BaseOutput, logging | |
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor | |
from ..embeddings import TimestepEmbedding, Timesteps | |
from ..modeling_utils import ModelMixin | |
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
@dataclass | |
class UNetSpatioTemporalConditionOutput(BaseOutput): | |
""" | |
The output of [`UNetSpatioTemporalConditionModel`]. | |
Args: | |
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): | |
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. | |
""" | |
sample: torch.FloatTensor = None | |
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): | |
r""" | |
A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep and returns a | |
sample-shaped output. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented | |
for all models (such as downloading or saving). | |
Parameters: | |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): | |
Height and width of input/output sample. | |
in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. | |
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): | |
The tuple of downsample blocks to use. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): | |
The tuple of upsample blocks to use. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
The tuple of output channels for each block. | |
addition_time_embed_dim: (`int`, defaults to 256): | |
Dimension used to encode the additional time ids. | |
projection_class_embeddings_input_dim (`int`, defaults to 768): | |
The dimension of the projection of encoded `added_time_ids`. | |
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. | |
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): | |
The dimension of the cross attention features. | |
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): | |
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for | |
[`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], | |
[`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. | |
num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`): | |
The number of attention heads. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
""" | |
_supports_gradient_checkpointing = True | |
@register_to_config | |
def __init__( | |
self, | |
sample_size: Optional[int] = None, | |
in_channels: int = 8, | |
out_channels: int = 4, | |
down_block_types: Tuple[str] = ( | |
"CrossAttnDownBlockSpatioTemporal", | |
"CrossAttnDownBlockSpatioTemporal", | |
"CrossAttnDownBlockSpatioTemporal", | |
"DownBlockSpatioTemporal", | |
), | |
up_block_types: Tuple[str] = ( | |
"UpBlockSpatioTemporal", | |
"CrossAttnUpBlockSpatioTemporal", | |
"CrossAttnUpBlockSpatioTemporal", | |
"CrossAttnUpBlockSpatioTemporal", | |
), | |
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), | |
addition_time_embed_dim: int = 256, | |
projection_class_embeddings_input_dim: int = 768, | |
layers_per_block: Union[int, Tuple[int]] = 2, | |
cross_attention_dim: Union[int, Tuple[int]] = 1024, | |
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, | |
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), | |
num_frames: int = 25, | |
): | |
super().__init__() | |
self.sample_size = sample_size | |
# Check inputs | |
if len(down_block_types) != len(up_block_types): | |
raise ValueError( | |
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." | |
) | |
if len(block_out_channels) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." | |
) | |
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." | |
) | |
# input | |
self.conv_in = nn.Conv2d( | |
in_channels, | |
block_out_channels[0], | |
kernel_size=3, | |
padding=1, | |
) | |
# time | |
time_embed_dim = block_out_channels[0] * 4 | |
self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) | |
timestep_input_dim = block_out_channels[0] | |
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) | |
self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) | |
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) | |
self.down_blocks = nn.ModuleList([]) | |
self.up_blocks = nn.ModuleList([]) | |
if isinstance(num_attention_heads, int): | |
num_attention_heads = (num_attention_heads,) * len(down_block_types) | |
if isinstance(cross_attention_dim, int): | |
cross_attention_dim = (cross_attention_dim,) * len(down_block_types) | |
if isinstance(layers_per_block, int): | |
layers_per_block = [layers_per_block] * len(down_block_types) | |
if isinstance(transformer_layers_per_block, int): | |
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) | |
blocks_time_embed_dim = time_embed_dim | |
# down | |
output_channel = block_out_channels[0] | |
for i, down_block_type in enumerate(down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = get_down_block( | |
down_block_type, | |
num_layers=layers_per_block[i], | |
transformer_layers_per_block=transformer_layers_per_block[i], | |
in_channels=input_channel, | |
out_channels=output_channel, | |
temb_channels=blocks_time_embed_dim, | |
add_downsample=not is_final_block, | |
resnet_eps=1e-5, | |
cross_attention_dim=cross_attention_dim[i], | |
num_attention_heads=num_attention_heads[i], | |
resnet_act_fn="silu", | |
) | |
self.down_blocks.append(down_block) | |
# mid | |
self.mid_block = UNetMidBlockSpatioTemporal( | |
block_out_channels[-1], | |
temb_channels=blocks_time_embed_dim, | |
transformer_layers_per_block=transformer_layers_per_block[-1], | |
cross_attention_dim=cross_attention_dim[-1], | |
num_attention_heads=num_attention_heads[-1], | |
) | |
# count how many layers upsample the images | |
self.num_upsamplers = 0 | |
# up | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
reversed_num_attention_heads = list(reversed(num_attention_heads)) | |
reversed_layers_per_block = list(reversed(layers_per_block)) | |
reversed_cross_attention_dim = list(reversed(cross_attention_dim)) | |
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) | |
output_channel = reversed_block_out_channels[0] | |
for i, up_block_type in enumerate(up_block_types): | |
is_final_block = i == len(block_out_channels) - 1 | |
prev_output_channel = output_channel | |
output_channel = reversed_block_out_channels[i] | |
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] | |
# add upsample block for all BUT final layer | |
if not is_final_block: | |
add_upsample = True | |
self.num_upsamplers += 1 | |
else: | |
add_upsample = False | |
up_block = get_up_block( | |
up_block_type, | |
num_layers=reversed_layers_per_block[i] + 1, | |
transformer_layers_per_block=reversed_transformer_layers_per_block[i], | |
in_channels=input_channel, | |
out_channels=output_channel, | |
prev_output_channel=prev_output_channel, | |
temb_channels=blocks_time_embed_dim, | |
add_upsample=add_upsample, | |
resnet_eps=1e-5, | |
resolution_idx=i, | |
cross_attention_dim=reversed_cross_attention_dim[i], | |
num_attention_heads=reversed_num_attention_heads[i], | |
resnet_act_fn="silu", | |
) | |
self.up_blocks.append(up_block) | |
prev_output_channel = output_channel | |
# out | |
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) | |
self.conv_act = nn.SiLU() | |
self.conv_out = nn.Conv2d( | |
block_out_channels[0], | |
out_channels, | |
kernel_size=3, | |
padding=1, | |
) | |
@property | |
def attn_processors(self) -> Dict[str, AttentionProcessor]: | |
r""" | |
Returns: | |
`dict` of attention processors: A dictionary containing all attention processors used in the model, | |
indexed by its weight name. | |
""" | |
# set recursively | |
processors = {} | |
def fn_recursive_add_processors( | |
name: str, | |
module: torch.nn.Module, | |
processors: Dict[str, AttentionProcessor], | |
): | |
if hasattr(module, "get_processor"): | |
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) | |
for sub_name, child in module.named_children(): | |
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) | |
return processors | |
for name, module in self.named_children(): | |
fn_recursive_add_processors(name, module, processors) | |
return processors | |
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): | |
r""" | |
Sets the attention processor to use to compute attention. | |
Parameters: | |
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): | |
The instantiated processor class or a dictionary of processor classes that will be set as the processor | |
for **all** `Attention` layers. | |
If `processor` is a dict, the key needs to define the path to the corresponding cross attention | |
processor. This is strongly recommended when setting trainable attention processors. | |
""" | |
count = len(self.attn_processors.keys()) | |
if isinstance(processor, dict) and len(processor) != count: | |
raise ValueError( | |
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" | |
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." | |
) | |
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
if hasattr(module, "set_processor"): | |
if not isinstance(processor, dict): | |
module.set_processor(processor) | |
else: | |
module.set_processor(processor.pop(f"{name}.processor")) | |
for sub_name, child in module.named_children(): | |
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) | |
for name, module in self.named_children(): | |
fn_recursive_attn_processor(name, module, processor) | |
def set_default_attn_processor(self): | |
""" | |
Disables custom attention processors and sets the default attention implementation. | |
""" | |
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnProcessor() | |
else: | |
raise ValueError( | |
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" | |
) | |
self.set_attn_processor(processor) | |
def _set_gradient_checkpointing(self, module, value=False): | |
if hasattr(module, "gradient_checkpointing"): | |
module.gradient_checkpointing = value | |
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking | |
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: | |
""" | |
Sets the attention processor to use [feed forward | |
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). | |
Parameters: | |
chunk_size (`int`, *optional*): | |
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually | |
over each tensor of dim=`dim`. | |
dim (`int`, *optional*, defaults to `0`): | |
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) | |
or dim=1 (sequence length). | |
""" | |
if dim not in [0, 1]: | |
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") | |
# By default chunk size is 1 | |
chunk_size = chunk_size or 1 | |
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): | |
if hasattr(module, "set_chunk_feed_forward"): | |
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) | |
for child in module.children(): | |
fn_recursive_feed_forward(child, chunk_size, dim) | |
for module in self.children(): | |
fn_recursive_feed_forward(module, chunk_size, dim) | |
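# Hedged usage sketch (comment only): feed-forward chunking on the spatio-temporal UNet lowers peak | |
# memory by splitting the transformer feed-forward computation along `dim`. `unet` is assumed to be | |
# an instance of this class. | |
# unet.enable_forward_chunking(chunk_size=1, dim=0)   # chunk along the batch dimension, one row at a time | |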
def forward( | |
self, | |
sample: torch.FloatTensor, | |
timestep: Union[torch.Tensor, float, int], | |
encoder_hidden_states: torch.Tensor, | |
added_time_ids: torch.Tensor, | |
return_dict: bool = True, | |
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: | |
r""" | |
The [`UNetSpatioTemporalConditionModel`] forward method. | |
Args: | |
sample (`torch.FloatTensor`): | |
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. | |
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. | |
encoder_hidden_states (`torch.FloatTensor`): | |
The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. | |
added_time_ids: (`torch.FloatTensor`): | |
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal | |
embeddings and added to the time embeddings. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain | |
tuple. | |
Returns: | |
[`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: | |
If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise | |
a `tuple` is returned where the first element is the sample tensor. | |
""" | |
# 1. time | |
timesteps = timestep | |
if not torch.is_tensor(timesteps): | |
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
# This would be a good case for the `match` statement (Python 3.10+) | |
is_mps = sample.device.type == "mps" | |
if isinstance(timestep, float): | |
dtype = torch.float32 if is_mps else torch.float64 | |
else: | |
dtype = torch.int32 if is_mps else torch.int64 | |
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
elif len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(sample.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
batch_size, num_frames = sample.shape[:2] | |
timesteps = timesteps.expand(batch_size) | |
t_emb = self.time_proj(timesteps) | |
# `Timesteps` does not contain any weights and will always return f32 tensors | |
# but time_embedding might actually be running in fp16. so we need to cast here. | |
# there might be better ways to encapsulate this. | |
t_emb = t_emb.to(dtype=sample.dtype) | |
emb = self.time_embedding(t_emb) | |
time_embeds = self.add_time_proj(added_time_ids.flatten()) | |
time_embeds = time_embeds.reshape((batch_size, -1)) | |
time_embeds = time_embeds.to(emb.dtype) | |
aug_emb = self.add_embedding(time_embeds) | |
emb = emb + aug_emb | |
# Flatten the batch and frames dimensions | |
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] | |
sample = sample.flatten(0, 1) | |
# Repeat the embeddings num_video_frames times | |
# emb: [batch, channels] -> [batch * frames, channels] | |
emb = emb.repeat_interleave(num_frames, dim=0) | |
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] | |
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) | |
# 2. pre-process | |
sample = self.conv_in(sample) | |
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) | |
# 3. down | |
down_block_res_samples = (sample,) | |
for downsample_block in self.down_blocks: | |
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
sample, res_samples = downsample_block( | |
hidden_states=sample, | |
temb=emb, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
) | |
else: | |
sample, res_samples = downsample_block( | |
hidden_states=sample, | |
temb=emb, | |
image_only_indicator=image_only_indicator, | |
) | |
down_block_res_samples += res_samples | |
# 4. mid | |
sample = self.mid_block( | |
hidden_states=sample, | |
temb=emb, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
) | |
# 5. up | |
for i, upsample_block in enumerate(self.up_blocks): | |
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
encoder_hidden_states=encoder_hidden_states, | |
image_only_indicator=image_only_indicator, | |
) | |
else: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
image_only_indicator=image_only_indicator, | |
) | |
# 6. post-process | |
sample = self.conv_norm_out(sample) | |
sample = self.conv_act(sample) | |
sample = self.conv_out(sample) | |
# 7. Reshape back to original shape | |
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) | |
if not return_dict: | |
return (sample,) | |
return UNetSpatioTemporalConditionOutput(sample=sample) | |
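# --- Usage sketch (not part of the library source) -------------------------------------------
# A minimal forward-pass check for UNetSpatioTemporalConditionModel. The constructor kwargs
# below are assumptions based on the public diffusers API (the __init__ is defined above this
# excerpt); adjust them to match your checkpoint. Real Stable Video Diffusion weights can
# instead be loaded with `UNetSpatioTemporalConditionModel.from_pretrained(...)`.
if __name__ == "__main__":
    import torch

    tiny_unet = UNetSpatioTemporalConditionModel(
        sample_size=32,
        in_channels=8,                      # 4 noisy latent channels + 4 image-conditioning channels (SVD-style)
        out_channels=4,
        down_block_types=("CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal"),
        up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
        block_out_channels=(32, 64),
        layers_per_block=1,
        cross_attention_dim=32,
        num_attention_heads=(2, 4),
        addition_time_embed_dim=32,
        projection_class_embeddings_input_dim=96,   # 3 added time ids * addition_time_embed_dim
        num_frames=4,
    )
    batch, frames = 1, 4
    sample = torch.randn(batch, frames, 8, 32, 32)          # (batch, num_frames, channel, height, width)
    encoder_hidden_states = torch.randn(batch, 1, 32)       # (batch, sequence_length, cross_attention_dim)
    added_time_ids = torch.tensor([[6.0, 127.0, 0.02]])     # (batch, num_additional_ids): e.g. fps, motion bucket, noise aug
    out = tiny_unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states,
                    added_time_ids=added_time_ids).sample
    assert out.shape == (batch, frames, 4, 32, 32)          # output keeps the (batch, frames) layout
# ----------------------------------------------------------------------------------------------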
import flax.linen as nn | |
import jax.numpy as jnp | |
from ..attention_flax import FlaxTransformer2DModel | |
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D | |
class FlaxCrossAttnDownBlock2D(nn.Module): | |
r""" | |
Cross Attention 2D Downsizing block - original architecture from Unet transformers: | |
https://arxiv.org/abs/2103.06104 | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
out_channels (:obj:`int`): | |
Output channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of attention block layers | |
num_attention_heads (:obj:`int`, *optional*, defaults to 1): | |
Number of attention heads of each spatial transformer block | |
add_downsample (:obj:`bool`, *optional*, defaults to `True`): | |
Whether to add a downsampling layer at the end of the block | |
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): | |
enable memory efficient attention https://arxiv.org/abs/2112.05682 | |
split_head_dim (`bool`, *optional*, defaults to `False`): | |
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, | |
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
out_channels: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
num_attention_heads: int = 1 | |
add_downsample: bool = True | |
use_linear_projection: bool = False | |
only_cross_attention: bool = False | |
use_memory_efficient_attention: bool = False | |
split_head_dim: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
transformer_layers_per_block: int = 1 | |
def setup(self): | |
resnets = [] | |
attentions = [] | |
for i in range(self.num_layers): | |
in_channels = self.in_channels if i == 0 else self.out_channels | |
res_block = FlaxResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=self.out_channels, | |
dropout_prob=self.dropout, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
attn_block = FlaxTransformer2DModel( | |
in_channels=self.out_channels, | |
n_heads=self.num_attention_heads, | |
d_head=self.out_channels // self.num_attention_heads, | |
depth=self.transformer_layers_per_block, | |
use_linear_projection=self.use_linear_projection, | |
only_cross_attention=self.only_cross_attention, | |
use_memory_efficient_attention=self.use_memory_efficient_attention, | |
split_head_dim=self.split_head_dim, | |
dtype=self.dtype, | |
) | |
attentions.append(attn_block) | |
self.resnets = resnets | |
self.attentions = attentions | |
if self.add_downsample: | |
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) | |
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): | |
output_states = () | |
for resnet, attn in zip(self.resnets, self.attentions): | |
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) | |
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) | |
output_states += (hidden_states,) | |
if self.add_downsample: | |
hidden_states = self.downsamplers_0(hidden_states) | |
output_states += (hidden_states,) | |
return hidden_states, output_states | |
class FlaxDownBlock2D(nn.Module): | |
r""" | |
Flax 2D downsizing block | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
out_channels (:obj:`int`): | |
Output channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of ResNet block layers | |
add_downsample (:obj:`bool`, *optional*, defaults to `True`): | |
Whether to add a downsampling layer at the end of the block | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
out_channels: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
add_downsample: bool = True | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
resnets = [] | |
for i in range(self.num_layers): | |
in_channels = self.in_channels if i == 0 else self.out_channels | |
res_block = FlaxResnetBlock2D( | |
in_channels=in_channels, | |
out_channels=self.out_channels, | |
dropout_prob=self.dropout, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
self.resnets = resnets | |
if self.add_downsample: | |
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) | |
def __call__(self, hidden_states, temb, deterministic=True): | |
output_states = () | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) | |
output_states += (hidden_states,) | |
if self.add_downsample: | |
hidden_states = self.downsamplers_0(hidden_states) | |
output_states += (hidden_states,) | |
return hidden_states, output_states | |
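# --- Usage sketch (not part of the library source) -------------------------------------------
# A minimal init/apply round trip for FlaxDownBlock2D. The Flax blocks operate on channels-last
# (NHWC) activations, unlike their PyTorch counterparts, and the channel count of 32 is chosen
# so the GroupNorm(32) inside FlaxResnetBlock2D divides it evenly. The shapes below are
# illustrative assumptions, not values taken from this file.
if __name__ == "__main__":
    import jax

    block = FlaxDownBlock2D(in_channels=32, out_channels=32, num_layers=1, add_downsample=True)
    hidden_states = jnp.zeros((1, 8, 8, 32))     # (batch, height, width, channels)
    temb = jnp.zeros((1, 128))                   # time embedding; inner Dense layers infer its width
    params = block.init(jax.random.PRNGKey(0), hidden_states, temb)
    out, skips = block.apply(params, hidden_states, temb)
    assert out.shape == (1, 4, 4, 32)            # spatially halved when add_downsample=True
    assert len(skips) == 2                       # one skip per resnet plus the downsampled output
# ----------------------------------------------------------------------------------------------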
class FlaxCrossAttnUpBlock2D(nn.Module): | |
r""" | |
Cross Attention 2D Upsampling block - original architecture from Unet transformers: | |
https://arxiv.org/abs/2103.06104 | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
out_channels (:obj:`int`): | |
Output channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of attention block layers | |
num_attention_heads (:obj:`int`, *optional*, defaults to 1): | |
Number of attention heads of each spatial transformer block | |
add_upsample (:obj:`bool`, *optional*, defaults to `True`): | |
Whether to add an upsampling layer at the end of the block | |
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): | |
enable memory efficient attention https://arxiv.org/abs/2112.05682 | |
split_head_dim (`bool`, *optional*, defaults to `False`): | |
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, | |
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
out_channels: int | |
prev_output_channel: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
num_attention_heads: int = 1 | |
add_upsample: bool = True | |
use_linear_projection: bool = False | |
only_cross_attention: bool = False | |
use_memory_efficient_attention: bool = False | |
split_head_dim: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
transformer_layers_per_block: int = 1 | |
def setup(self): | |
resnets = [] | |
attentions = [] | |
for i in range(self.num_layers): | |
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels | |
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels | |
res_block = FlaxResnetBlock2D( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=self.out_channels, | |
dropout_prob=self.dropout, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
attn_block = FlaxTransformer2DModel( | |
in_channels=self.out_channels, | |
n_heads=self.num_attention_heads, | |
d_head=self.out_channels // self.num_attention_heads, | |
depth=self.transformer_layers_per_block, | |
use_linear_projection=self.use_linear_projection, | |
only_cross_attention=self.only_cross_attention, | |
use_memory_efficient_attention=self.use_memory_efficient_attention, | |
split_head_dim=self.split_head_dim, | |
dtype=self.dtype, | |
) | |
attentions.append(attn_block) | |
self.resnets = resnets | |
self.attentions = attentions | |
if self.add_upsample: | |
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) | |
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): | |
for resnet, attn in zip(self.resnets, self.attentions): | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) | |
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) | |
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) | |
if self.add_upsample: | |
hidden_states = self.upsamplers_0(hidden_states) | |
return hidden_states | |
class FlaxUpBlock2D(nn.Module): | |
r""" | |
Flax 2D upsampling block | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
out_channels (:obj:`int`): | |
Output channels | |
prev_output_channel (:obj:`int`): | |
Output channels from the previous block | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of ResNet block layers | |
add_upsample (:obj:`bool`, *optional*, defaults to `True`): | |
Whether to add an upsampling layer at the end of the block | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
out_channels: int | |
prev_output_channel: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
add_upsample: bool = True | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
resnets = [] | |
for i in range(self.num_layers): | |
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels | |
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels | |
res_block = FlaxResnetBlock2D( | |
in_channels=resnet_in_channels + res_skip_channels, | |
out_channels=self.out_channels, | |
dropout_prob=self.dropout, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
self.resnets = resnets | |
if self.add_upsample: | |
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) | |
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): | |
for resnet in self.resnets: | |
# pop res hidden states | |
res_hidden_states = res_hidden_states_tuple[-1] | |
res_hidden_states_tuple = res_hidden_states_tuple[:-1] | |
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) | |
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) | |
if self.add_upsample: | |
hidden_states = self.upsamplers_0(hidden_states) | |
return hidden_states | |
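# --- Usage sketch (not part of the library source) -------------------------------------------
# FlaxUpBlock2D concatenates one skip connection per resnet along the channel axis before each
# resnet, then optionally upsamples. The channel counts are illustrative assumptions, kept at
# multiples of 32 so the internal GroupNorm(32) layers divide them evenly.
if __name__ == "__main__":
    import jax

    block = FlaxUpBlock2D(in_channels=32, out_channels=32, prev_output_channel=32,
                          num_layers=1, add_upsample=True)
    hidden_states = jnp.zeros((1, 8, 8, 32))                 # NHWC activations from the previous block
    skips = (jnp.zeros((1, 8, 8, 32)),)                      # one skip tensor per resnet
    temb = jnp.zeros((1, 128))
    params = block.init(jax.random.PRNGKey(0), hidden_states, skips, temb)
    out = block.apply(params, hidden_states, skips, temb)
    assert out.shape == (1, 16, 16, 32)                      # spatially doubled when add_upsample=True
# ----------------------------------------------------------------------------------------------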
class FlaxUNetMidBlock2DCrossAttn(nn.Module): | |
r""" | |
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 | |
Parameters: | |
in_channels (:obj:`int`): | |
Input channels | |
dropout (:obj:`float`, *optional*, defaults to 0.0): | |
Dropout rate | |
num_layers (:obj:`int`, *optional*, defaults to 1): | |
Number of attention block layers | |
num_attention_heads (:obj:`int`, *optional*, defaults to 1): | |
Number of attention heads of each spatial transformer block | |
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): | |
enable memory efficient attention https://arxiv.org/abs/2112.05682 | |
split_head_dim (`bool`, *optional*, defaults to `False`): | |
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, | |
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. | |
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): | |
Parameters `dtype` | |
""" | |
in_channels: int | |
dropout: float = 0.0 | |
num_layers: int = 1 | |
num_attention_heads: int = 1 | |
use_linear_projection: bool = False | |
use_memory_efficient_attention: bool = False | |
split_head_dim: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
transformer_layers_per_block: int = 1 | |
def setup(self): | |
# there is always at least one resnet | |
resnets = [ | |
FlaxResnetBlock2D( | |
in_channels=self.in_channels, | |
out_channels=self.in_channels, | |
dropout_prob=self.dropout, | |
dtype=self.dtype, | |
) | |
] | |
attentions = [] | |
for _ in range(self.num_layers): | |
attn_block = FlaxTransformer2DModel( | |
in_channels=self.in_channels, | |
n_heads=self.num_attention_heads, | |
d_head=self.in_channels // self.num_attention_heads, | |
depth=self.transformer_layers_per_block, | |
use_linear_projection=self.use_linear_projection, | |
use_memory_efficient_attention=self.use_memory_efficient_attention, | |
split_head_dim=self.split_head_dim, | |
dtype=self.dtype, | |
) | |
attentions.append(attn_block) | |
res_block = FlaxResnetBlock2D( | |
in_channels=self.in_channels, | |
out_channels=self.in_channels, | |
dropout_prob=self.dropout, | |
dtype=self.dtype, | |
) | |
resnets.append(res_block) | |
self.resnets = resnets | |
self.attentions = attentions | |
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): | |
hidden_states = self.resnets[0](hidden_states, temb) | |
for attn, resnet in zip(self.attentions, self.resnets[1:]): | |
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) | |
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) | |
return hidden_states | |
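# --- Usage sketch (not part of the library source) -------------------------------------------
# The mid block interleaves resnets with FlaxTransformer2DModel cross-attention layers, so it
# additionally needs `encoder_hidden_states` (e.g. text-encoder features). All shapes below are
# illustrative assumptions; the sequence length and width of the conditioning are free because
# the inner Dense projections infer their input size.
if __name__ == "__main__":
    import jax

    mid = FlaxUNetMidBlock2DCrossAttn(in_channels=32, num_layers=1, num_attention_heads=2)
    hidden_states = jnp.zeros((1, 8, 8, 32))        # NHWC activations
    temb = jnp.zeros((1, 128))                      # time embedding
    encoder_hidden_states = jnp.zeros((1, 4, 32))   # (batch, sequence_length, context_dim)
    params = mid.init(jax.random.PRNGKey(0), hidden_states, temb, encoder_hidden_states)
    out = mid.apply(params, hidden_states, temb, encoder_hidden_states)
    assert out.shape == hidden_states.shape         # the mid block preserves the spatial resolution
# ----------------------------------------------------------------------------------------------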
from dataclasses import dataclass | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...utils import BaseOutput | |
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps | |
from ..modeling_utils import ModelMixin | |
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block | |
@dataclass | |
class UNet2DOutput(BaseOutput): | |
""" | |
The output of [`UNet2DModel`]. | |
Args: | |
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): | |
The hidden states output from the last layer of the model. | |
""" | |
sample: torch.FloatTensor | |
class UNet2DModel(ModelMixin, ConfigMixin): | |
r""" | |
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented | |
for all models (such as downloading or saving). | |
Parameters: | |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): | |
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - | |
1)`. | |
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. | |
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. | |
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. | |
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. | |
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. | |
flip_sin_to_cos (`bool`, *optional*, defaults to `True`): | |
Whether to flip sin to cos for Fourier time embedding. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): | |
Tuple of downsample block types. | |
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): | |
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): | |
Tuple of upsample block types. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): | |
Tuple of block output channels. | |
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. | |
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. | |
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. | |
downsample_type (`str`, *optional*, defaults to `conv`): | |
The downsample type for downsampling layers. Choose between "conv" and "resnet" | |
upsample_type (`str`, *optional*, defaults to `conv`): | |
The upsample type for upsampling layers. Choose between "conv" and "resnet" | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. | |
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. | |
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. | |
attn_norm_num_groups (`int`, *optional*, defaults to `None`): | |
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the | |
given number of groups. If left as `None`, the group norm layer will only be created if | |
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups. | |
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. | |
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config | |
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. | |
class_embed_type (`str`, *optional*, defaults to `None`): | |
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, | |
`"timestep"`, or `"identity"`. | |
num_class_embeds (`int`, *optional*, defaults to `None`): | |
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class | |
conditioning with `class_embed_type` equal to `None`. | |
""" | |
@register_to_config | |
def __init__( | |
self, | |
sample_size: Optional[Union[int, Tuple[int, int]]] = None, | |
in_channels: int = 3, | |
out_channels: int = 3, | |
center_input_sample: bool = False, | |
time_embedding_type: str = "positional", | |
freq_shift: int = 0, | |
flip_sin_to_cos: bool = True, | |
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), | |
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), | |
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), | |
layers_per_block: int = 2, | |
mid_block_scale_factor: float = 1, | |
downsample_padding: int = 1, | |
downsample_type: str = "conv", | |
upsample_type: str = "conv", | |
dropout: float = 0.0, | |
act_fn: str = "silu", | |
attention_head_dim: Optional[int] = 8, | |
norm_num_groups: int = 32, | |
attn_norm_num_groups: Optional[int] = None, | |
norm_eps: float = 1e-5, | |
resnet_time_scale_shift: str = "default", | |
add_attention: bool = True, | |
class_embed_type: Optional[str] = None, | |
num_class_embeds: Optional[int] = None, | |
num_train_timesteps: Optional[int] = None, | |
): | |
super().__init__() | |
self.sample_size = sample_size | |
time_embed_dim = block_out_channels[0] * 4 | |
# Check inputs | |
if len(down_block_types) != len(up_block_types): | |
raise ValueError( | |
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." | |
) | |
if len(block_out_channels) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." | |
) | |
# input | |
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) | |
# time | |
if time_embedding_type == "fourier": | |
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) | |
timestep_input_dim = 2 * block_out_channels[0] | |
elif time_embedding_type == "positional": | |
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
timestep_input_dim = block_out_channels[0] | |
elif time_embedding_type == "learned": | |
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) | |
timestep_input_dim = block_out_channels[0] | |
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) | |
# class embedding | |
if class_embed_type is None and num_class_embeds is not None: | |
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) | |
elif class_embed_type == "timestep": | |
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) | |
elif class_embed_type == "identity": | |
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) | |
else: | |
self.class_embedding = None | |
self.down_blocks = nn.ModuleList([]) | |
self.mid_block = None | |
self.up_blocks = nn.ModuleList([]) | |
# down | |
output_channel = block_out_channels[0] | |
for i, down_block_type in enumerate(down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = get_down_block( | |
down_block_type, | |
num_layers=layers_per_block, | |
in_channels=input_channel, | |
out_channels=output_channel, | |
temb_channels=time_embed_dim, | |
add_downsample=not is_final_block, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, | |
downsample_padding=downsample_padding, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
downsample_type=downsample_type, | |
dropout=dropout, | |
) | |
self.down_blocks.append(down_block) | |
# mid | |
self.mid_block = UNetMidBlock2D( | |
in_channels=block_out_channels[-1], | |
temb_channels=time_embed_dim, | |
dropout=dropout, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
output_scale_factor=mid_block_scale_factor, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], | |
resnet_groups=norm_num_groups, | |
attn_groups=attn_norm_num_groups, | |
add_attention=add_attention, | |
) | |
# up | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
output_channel = reversed_block_out_channels[0] | |
for i, up_block_type in enumerate(up_block_types): | |
prev_output_channel = output_channel | |
output_channel = reversed_block_out_channels[i] | |
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] | |
is_final_block = i == len(block_out_channels) - 1 | |
up_block = get_up_block( | |
up_block_type, | |
num_layers=layers_per_block + 1, | |
in_channels=input_channel, | |
out_channels=output_channel, | |
prev_output_channel=prev_output_channel, | |
temb_channels=time_embed_dim, | |
add_upsample=not is_final_block, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
upsample_type=upsample_type, | |
dropout=dropout, | |
) | |
self.up_blocks.append(up_block) | |
prev_output_channel = output_channel | |
# out | |
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) | |
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) | |
self.conv_act = nn.SiLU() | |
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) | |
def forward( | |
self, | |
sample: torch.FloatTensor, | |
timestep: Union[torch.Tensor, float, int], | |
class_labels: Optional[torch.Tensor] = None, | |
return_dict: bool = True, | |
) -> Union[UNet2DOutput, Tuple]: | |
r""" | |
The [`UNet2DModel`] forward method. | |
Args: | |
sample (`torch.FloatTensor`): | |
The noisy input tensor with the following shape `(batch, channel, height, width)`. | |
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. | |
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): | |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. | |
Returns: | |
[`~models.unet_2d.UNet2DOutput`] or `tuple`: | |
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is | |
returned where the first element is the sample tensor. | |
""" | |
# 0. center input if necessary | |
if self.config.center_input_sample: | |
sample = 2 * sample - 1.0 | |
# 1. time | |
timesteps = timestep | |
if not torch.is_tensor(timesteps): | |
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) | |
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(sample.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) | |
t_emb = self.time_proj(timesteps) | |
# timesteps does not contain any weights and will always return f32 tensors | |
# but time_embedding might actually be running in fp16. so we need to cast here. | |
# there might be better ways to encapsulate this. | |
t_emb = t_emb.to(dtype=self.dtype) | |
emb = self.time_embedding(t_emb) | |
if self.class_embedding is not None: | |
if class_labels is None: | |
raise ValueError("class_labels should be provided when doing class conditioning") | |
if self.config.class_embed_type == "timestep": | |
class_labels = self.time_proj(class_labels) | |
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) | |
emb = emb + class_emb | |
elif self.class_embedding is None and class_labels is not None: | |
raise ValueError("class_embedding needs to be initialized in order to use class conditioning") | |
# 2. pre-process | |
skip_sample = sample | |
sample = self.conv_in(sample) | |
# 3. down | |
down_block_res_samples = (sample,) | |
for downsample_block in self.down_blocks: | |
if hasattr(downsample_block, "skip_conv"): | |
sample, res_samples, skip_sample = downsample_block( | |
hidden_states=sample, temb=emb, skip_sample=skip_sample | |
) | |
else: | |
sample, res_samples = downsample_block(hidden_states=sample, temb=emb) | |
down_block_res_samples += res_samples | |
# 4. mid | |
sample = self.mid_block(sample, emb) | |
# 5. up | |
skip_sample = None | |
for upsample_block in self.up_blocks: | |
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
if hasattr(upsample_block, "skip_conv"): | |
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) | |
else: | |
sample = upsample_block(sample, res_samples, emb) | |
# 6. post-process | |
sample = self.conv_norm_out(sample) | |
sample = self.conv_act(sample) | |
sample = self.conv_out(sample) | |
if skip_sample is not None: | |
sample += skip_sample | |
if self.config.time_embedding_type == "fourier": | |
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) | |
sample = sample / timesteps | |
if not return_dict: | |
return (sample,) | |
return UNet2DOutput(sample=sample) | |
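# --- Usage sketch (not part of the library source) -------------------------------------------
# A tiny, randomly initialised UNet2DModel exercising the forward contract documented above.
# The block names and channel counts are assumptions chosen small enough to run on CPU;
# block_out_channels must stay divisible by norm_num_groups (32 by default).
if __name__ == "__main__":
    tiny_unet2d = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        layers_per_block=1,
    )
    noisy = torch.randn(1, 3, 32, 32)                 # (batch, channel, height, width)
    pred = tiny_unet2d(noisy, timestep=999).sample    # model output, same shape as the input
    assert pred.shape == noisy.shape
# ----------------------------------------------------------------------------------------------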
from typing import Any, Dict, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.utils.checkpoint | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...loaders import UNet2DConditionLoadersMixin | |
from ...utils import logging | |
from ..attention_processor import ( | |
ADDED_KV_ATTENTION_PROCESSORS, | |
CROSS_ATTENTION_PROCESSORS, | |
Attention, | |
AttentionProcessor, | |
AttnAddedKVProcessor, | |
AttnProcessor, | |
) | |
from ..embeddings import TimestepEmbedding, Timesteps | |
from ..modeling_utils import ModelMixin | |
from ..transformers.transformer_temporal import TransformerTemporalModel | |
from .unet_2d_blocks import UNetMidBlock2DCrossAttn | |
from .unet_2d_condition import UNet2DConditionModel | |
from .unet_3d_blocks import ( | |
CrossAttnDownBlockMotion, | |
CrossAttnUpBlockMotion, | |
DownBlockMotion, | |
UNetMidBlockCrossAttnMotion, | |
UpBlockMotion, | |
get_down_block, | |
get_up_block, | |
) | |
from .unet_3d_condition import UNet3DConditionOutput | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
class MotionModules(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
layers_per_block: int = 2, | |
num_attention_heads: int = 8, | |
attention_bias: bool = False, | |
cross_attention_dim: Optional[int] = None, | |
activation_fn: str = "geglu", | |
norm_num_groups: int = 32, | |
max_seq_length: int = 32, | |
): | |
super().__init__() | |
self.motion_modules = nn.ModuleList([]) | |
for i in range(layers_per_block): | |
self.motion_modules.append( | |
TransformerTemporalModel( | |
in_channels=in_channels, | |
norm_num_groups=norm_num_groups, | |
cross_attention_dim=cross_attention_dim, | |
activation_fn=activation_fn, | |
attention_bias=attention_bias, | |
num_attention_heads=num_attention_heads, | |
attention_head_dim=in_channels // num_attention_heads, | |
positional_embeddings="sinusoidal", | |
num_positional_embeddings=max_seq_length, | |
) | |
) | |
class MotionAdapter(ModelMixin, ConfigMixin): | |
@register_to_config | |
def __init__( | |
self, | |
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), | |
motion_layers_per_block: int = 2, | |
motion_mid_block_layers_per_block: int = 1, | |
motion_num_attention_heads: int = 8, | |
motion_norm_num_groups: int = 32, | |
motion_max_seq_length: int = 32, | |
use_motion_mid_block: bool = True, | |
conv_in_channels: Optional[int] = None, | |
): | |
"""Container to store AnimateDiff Motion Modules | |
Args: | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
The tuple of output channels for each UNet block. | |
motion_layers_per_block (`int`, *optional*, defaults to 2): | |
The number of motion layers per UNet block. | |
motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): | |
The number of motion layers in the middle UNet block. | |
motion_num_attention_heads (`int`, *optional*, defaults to 8): | |
The number of heads to use in each attention layer of the motion module. | |
motion_norm_num_groups (`int`, *optional*, defaults to 32): | |
The number of groups to use in each group normalization layer of the motion module. | |
motion_max_seq_length (`int`, *optional*, defaults to 32): | |
The maximum sequence length to use in the motion module. | |
use_motion_mid_block (`bool`, *optional*, defaults to True): | |
Whether to use a motion module in the middle of the UNet. | |
""" | |
super().__init__() | |
down_blocks = [] | |
up_blocks = [] | |
if conv_in_channels: | |
# input | |
self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1) | |
else: | |
self.conv_in = None | |
for i, channel in enumerate(block_out_channels): | |
output_channel = block_out_channels[i] | |
down_blocks.append( | |
MotionModules( | |
in_channels=output_channel, | |
norm_num_groups=motion_norm_num_groups, | |
cross_attention_dim=None, | |
activation_fn="geglu", | |
attention_bias=False, | |
num_attention_heads=motion_num_attention_heads, | |
max_seq_length=motion_max_seq_length, | |
layers_per_block=motion_layers_per_block, | |
) | |
) | |
if use_motion_mid_block: | |
self.mid_block = MotionModules( | |
in_channels=block_out_channels[-1], | |
norm_num_groups=motion_norm_num_groups, | |
cross_attention_dim=None, | |
activation_fn="geglu", | |
attention_bias=False, | |
num_attention_heads=motion_num_attention_heads, | |
layers_per_block=motion_mid_block_layers_per_block, | |
max_seq_length=motion_max_seq_length, | |
) | |
else: | |
self.mid_block = None | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
output_channel = reversed_block_out_channels[0] | |
for i, channel in enumerate(reversed_block_out_channels): | |
output_channel = reversed_block_out_channels[i] | |
up_blocks.append( | |
MotionModules( | |
in_channels=output_channel, | |
norm_num_groups=motion_norm_num_groups, | |
cross_attention_dim=None, | |
activation_fn="geglu", | |
attention_bias=False, | |
num_attention_heads=motion_num_attention_heads, | |
max_seq_length=motion_max_seq_length, | |
layers_per_block=motion_layers_per_block + 1, | |
) | |
) | |
self.down_blocks = nn.ModuleList(down_blocks) | |
self.up_blocks = nn.ModuleList(up_blocks) | |
def forward(self, sample): | |
pass | |
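# --- Usage sketch (not part of the library source) -------------------------------------------
# MotionAdapter is only a weight container: its forward() is a no-op and its motion modules are
# meant to be merged into a UNet via `UNetMotionModel.from_unet2d(unet, motion_adapter=adapter)`
# (defined below). Constructing it with the default config gives one MotionModules stack per
# UNet block. Pretrained AnimateDiff adapters can also be loaded with
# `MotionAdapter.from_pretrained(...)`; the exact repo id is checkpoint-specific and not shown here.
if __name__ == "__main__":
    adapter = MotionAdapter()                                  # default (320, 640, 1280, 1280) layout
    assert len(adapter.down_blocks) == 4 and len(adapter.up_blocks) == 4
    n_params = sum(p.numel() for p in adapter.parameters())
    print(f"motion adapter parameters: {n_params / 1e6:.1f}M")  # temporal transformer weights only
# ----------------------------------------------------------------------------------------------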
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): | |
r""" | |
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a | |
sample shaped output. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented | |
for all models (such as downloading or saving). | |
""" | |
_supports_gradient_checkpointing = True | |
@register_to_config | |
def __init__( | |
self, | |
sample_size: Optional[int] = None, | |
in_channels: int = 4, | |
out_channels: int = 4, | |
down_block_types: Tuple[str, ...] = ( | |
"CrossAttnDownBlockMotion", | |
"CrossAttnDownBlockMotion", | |
"CrossAttnDownBlockMotion", | |
"DownBlockMotion", | |
), | |
up_block_types: Tuple[str, ...] = ( | |
"UpBlockMotion", | |
"CrossAttnUpBlockMotion", | |
"CrossAttnUpBlockMotion", | |
"CrossAttnUpBlockMotion", | |
), | |
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), | |
layers_per_block: int = 2, | |
downsample_padding: int = 1, | |
mid_block_scale_factor: float = 1, | |
act_fn: str = "silu", | |
norm_num_groups: int = 32, | |
norm_eps: float = 1e-5, | |
cross_attention_dim: int = 1280, | |
use_linear_projection: bool = False, | |
num_attention_heads: Union[int, Tuple[int, ...]] = 8, | |
motion_max_seq_length: int = 32, | |
motion_num_attention_heads: int = 8, | |
use_motion_mid_block: bool = True, | |
encoder_hid_dim: Optional[int] = None, | |
encoder_hid_dim_type: Optional[str] = None, | |
time_cond_proj_dim: Optional[int] = None, | |
): | |
super().__init__() | |
self.sample_size = sample_size | |
# Check inputs | |
if len(down_block_types) != len(up_block_types): | |
raise ValueError( | |
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." | |
) | |
if len(block_out_channels) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." | |
) | |
# input | |
conv_in_kernel = 3 | |
conv_out_kernel = 3 | |
conv_in_padding = (conv_in_kernel - 1) // 2 | |
self.conv_in = nn.Conv2d( | |
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding | |
) | |
# time | |
time_embed_dim = block_out_channels[0] * 4 | |
self.time_proj = Timesteps(block_out_channels[0], True, 0) | |
timestep_input_dim = block_out_channels[0] | |
self.time_embedding = TimestepEmbedding( | |
timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim | |
) | |
if encoder_hid_dim_type is None: | |
self.encoder_hid_proj = None | |
# class embedding | |
self.down_blocks = nn.ModuleList([]) | |
self.up_blocks = nn.ModuleList([]) | |
if isinstance(num_attention_heads, int): | |
num_attention_heads = (num_attention_heads,) * len(down_block_types) | |
# down | |
output_channel = block_out_channels[0] | |
for i, down_block_type in enumerate(down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = get_down_block( | |
down_block_type, | |
num_layers=layers_per_block, | |
in_channels=input_channel, | |
out_channels=output_channel, | |
temb_channels=time_embed_dim, | |
add_downsample=not is_final_block, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads[i], | |
downsample_padding=downsample_padding, | |
use_linear_projection=use_linear_projection, | |
dual_cross_attention=False, | |
temporal_num_attention_heads=motion_num_attention_heads, | |
temporal_max_seq_length=motion_max_seq_length, | |
) | |
self.down_blocks.append(down_block) | |
# mid | |
if use_motion_mid_block: | |
self.mid_block = UNetMidBlockCrossAttnMotion( | |
in_channels=block_out_channels[-1], | |
temb_channels=time_embed_dim, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
output_scale_factor=mid_block_scale_factor, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads[-1], | |
resnet_groups=norm_num_groups, | |
dual_cross_attention=False, | |
use_linear_projection=use_linear_projection, | |
temporal_num_attention_heads=motion_num_attention_heads, | |
temporal_max_seq_length=motion_max_seq_length, | |
) | |
else: | |
self.mid_block = UNetMidBlock2DCrossAttn( | |
in_channels=block_out_channels[-1], | |
temb_channels=time_embed_dim, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
output_scale_factor=mid_block_scale_factor, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=num_attention_heads[-1], | |
resnet_groups=norm_num_groups, | |
dual_cross_attention=False, | |
use_linear_projection=use_linear_projection, | |
) | |
# count how many layers upsample the images | |
self.num_upsamplers = 0 | |
# up | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
reversed_num_attention_heads = list(reversed(num_attention_heads)) | |
output_channel = reversed_block_out_channels[0] | |
for i, up_block_type in enumerate(up_block_types): | |
is_final_block = i == len(block_out_channels) - 1 | |
prev_output_channel = output_channel | |
output_channel = reversed_block_out_channels[i] | |
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] | |
# add upsample block for all BUT final layer | |
if not is_final_block: | |
add_upsample = True | |
self.num_upsamplers += 1 | |
else: | |
add_upsample = False | |
up_block = get_up_block( | |
up_block_type, | |
num_layers=layers_per_block + 1, | |
in_channels=input_channel, | |
out_channels=output_channel, | |
prev_output_channel=prev_output_channel, | |
temb_channels=time_embed_dim, | |
add_upsample=add_upsample, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
cross_attention_dim=cross_attention_dim, | |
num_attention_heads=reversed_num_attention_heads[i], | |
dual_cross_attention=False, | |
resolution_idx=i, | |
use_linear_projection=use_linear_projection, | |
temporal_num_attention_heads=motion_num_attention_heads, | |
temporal_max_seq_length=motion_max_seq_length, | |
) | |
self.up_blocks.append(up_block) | |
prev_output_channel = output_channel | |
# out | |
if norm_num_groups is not None: | |
self.conv_norm_out = nn.GroupNorm( | |
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps | |
) | |
self.conv_act = nn.SiLU() | |
else: | |
self.conv_norm_out = None | |
self.conv_act = None | |
conv_out_padding = (conv_out_kernel - 1) // 2 | |
self.conv_out = nn.Conv2d( | |
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding | |
) | |
@classmethod | |
def from_unet2d( | |
cls, | |
unet: UNet2DConditionModel, | |
motion_adapter: Optional[MotionAdapter] = None, | |
load_weights: bool = True, | |
): | |
has_motion_adapter = motion_adapter is not None | |
# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 | |
config = unet.config | |
config["_class_name"] = cls.__name__ | |
down_blocks = [] | |
for down_blocks_type in config["down_block_types"]: | |
if "CrossAttn" in down_blocks_type: | |
down_blocks.append("CrossAttnDownBlockMotion") | |
else: | |
down_blocks.append("DownBlockMotion") | |
config["down_block_types"] = down_blocks | |
up_blocks = [] | |
for down_blocks_type in config["up_block_types"]: | |
if "CrossAttn" in down_blocks_type: | |
up_blocks.append("CrossAttnUpBlockMotion") | |
else: | |
up_blocks.append("UpBlockMotion") | |
config["up_block_types"] = up_blocks | |
if has_motion_adapter: | |
config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] | |
config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] | |
config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] | |
# For PIA UNets we need to set the number of input channels to 9 | |
if motion_adapter.config["conv_in_channels"]: | |
config["in_channels"] = motion_adapter.config["conv_in_channels"] | |
# Need this for backwards compatibility with UNet2DConditionModel checkpoints | |
if not config.get("num_attention_heads"): | |
config["num_attention_heads"] = config["attention_head_dim"] | |
model = cls.from_config(config) | |
if not load_weights: | |
return model | |
# Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight | |
# while the last 5 channels must be PIA conv_in weights. | |
if has_motion_adapter and motion_adapter.config["conv_in_channels"]: | |
model.conv_in = motion_adapter.conv_in | |
updated_conv_in_weight = torch.cat( | |
[unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 | |
) | |
model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) | |
else: | |
model.conv_in.load_state_dict(unet.conv_in.state_dict()) | |
model.time_proj.load_state_dict(unet.time_proj.state_dict()) | |
model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) | |
for i, down_block in enumerate(unet.down_blocks): | |
model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) | |
if hasattr(model.down_blocks[i], "attentions"): | |
model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) | |
if model.down_blocks[i].downsamplers: | |
model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) | |
for i, up_block in enumerate(unet.up_blocks): | |
model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) | |
if hasattr(model.up_blocks[i], "attentions"): | |
model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) | |
if model.up_blocks[i].upsamplers: | |
model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) | |
model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) | |
model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) | |
if unet.conv_norm_out is not None: | |
model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) | |
if unet.conv_act is not None: | |
model.conv_act.load_state_dict(unet.conv_act.state_dict()) | |
model.conv_out.load_state_dict(unet.conv_out.state_dict()) | |
if has_motion_adapter: | |
model.load_motion_modules(motion_adapter) | |
# ensure that the Motion UNet is the same dtype as the UNet2DConditionModel | |
model.to(unet.dtype) | |
return model | |
def freeze_unet2d_params(self) -> None: | |
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules | |
unfrozen for fine tuning. | |
""" | |
# Freeze everything | |
for param in self.parameters(): | |
param.requires_grad = False | |
# Unfreeze Motion Modules | |
for down_block in self.down_blocks: | |
motion_modules = down_block.motion_modules | |
for param in motion_modules.parameters(): | |
param.requires_grad = True | |
for up_block in self.up_blocks: | |
motion_modules = up_block.motion_modules | |
for param in motion_modules.parameters(): | |
param.requires_grad = True | |
if hasattr(self.mid_block, "motion_modules"): | |
motion_modules = self.mid_block.motion_modules | |
for param in motion_modules.parameters(): | |
param.requires_grad = True | |
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None: | |
for i, down_block in enumerate(motion_adapter.down_blocks): | |
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) | |
for i, up_block in enumerate(motion_adapter.up_blocks): | |
self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict()) | |
# to support older motion modules that don't have a mid_block | |
if hasattr(self.mid_block, "motion_modules"): | |
self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) | |
def save_motion_modules( | |
self, | |
save_directory: str, | |
is_main_process: bool = True, | |
safe_serialization: bool = True, | |
variant: Optional[str] = None, | |
push_to_hub: bool = False, | |
**kwargs, | |
) -> None: | |
state_dict = self.state_dict() | |
# Extract all motion modules | |
motion_state_dict = {} | |
for k, v in state_dict.items(): | |
if "motion_modules" in k: | |
motion_state_dict[k] = v | |
adapter = MotionAdapter( | |
block_out_channels=self.config["block_out_channels"], | |
motion_layers_per_block=self.config["layers_per_block"], | |
motion_norm_num_groups=self.config["norm_num_groups"], | |
motion_num_attention_heads=self.config["motion_num_attention_heads"], | |
motion_max_seq_length=self.config["motion_max_seq_length"], | |
use_motion_mid_block=self.config["use_motion_mid_block"], | |
) | |
adapter.load_state_dict(motion_state_dict) | |
adapter.save_pretrained( | |
save_directory=save_directory, | |
is_main_process=is_main_process, | |
safe_serialization=safe_serialization, | |
variant=variant, | |
push_to_hub=push_to_hub, | |
**kwargs, | |
) | |
@property | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors | |
def attn_processors(self) -> Dict[str, AttentionProcessor]: | |
r""" | |
Returns: | |
`dict` of attention processors: A dictionary containing all attention processors used in the model, | |
indexed by their weight names. | |
""" | |
# set recursively | |
processors = {} | |
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): | |
if hasattr(module, "get_processor"): | |
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) | |
for sub_name, child in module.named_children(): | |
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) | |
return processors | |
for name, module in self.named_children(): | |
fn_recursive_add_processors(name, module, processors) | |
return processors | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor | |
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): | |
r""" | |
Sets the attention processor to use to compute attention. | |
Parameters: | |
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): | |
The instantiated processor class or a dictionary of processor classes that will be set as the processor | |
for **all** `Attention` layers. | |
If `processor` is a dict, the key needs to define the path to the corresponding cross attention | |
processor. This is strongly recommended when setting trainable attention processors. | |
""" | |
count = len(self.attn_processors.keys()) | |
if isinstance(processor, dict) and len(processor) != count: | |
raise ValueError( | |
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" | |
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." | |
) | |
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
if hasattr(module, "set_processor"): | |
if not isinstance(processor, dict): | |
module.set_processor(processor) | |
else: | |
module.set_processor(processor.pop(f"{name}.processor")) | |
for sub_name, child in module.named_children(): | |
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) | |
for name, module in self.named_children(): | |
fn_recursive_attn_processor(name, module, processor) | |
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking | |
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: | |
""" | |
Enables [feed forward | |
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). | |
Parameters: | |
chunk_size (`int`, *optional*): | |
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually | |
over each tensor of dim=`dim`. | |
dim (`int`, *optional*, defaults to `0`): | |
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) | |
or dim=1 (sequence length). | |
""" | |
if dim not in [0, 1]: | |
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") | |
# By default chunk size is 1 | |
chunk_size = chunk_size or 1 | |
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): | |
if hasattr(module, "set_chunk_feed_forward"): | |
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) | |
for child in module.children(): | |
fn_recursive_feed_forward(child, chunk_size, dim) | |
for module in self.children(): | |
fn_recursive_feed_forward(module, chunk_size, dim) | |
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking | |
def disable_forward_chunking(self) -> None: | |
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): | |
if hasattr(module, "set_chunk_feed_forward"): | |
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) | |
for child in module.children(): | |
fn_recursive_feed_forward(child, chunk_size, dim) | |
for module in self.children(): | |
fn_recursive_feed_forward(module, None, 0) | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor | |
def set_default_attn_processor(self) -> None: | |
""" | |
Disables custom attention processors and sets the default attention implementation. | |
""" | |
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnAddedKVProcessor() | |
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnProcessor() | |
else: | |
raise ValueError( | |
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" | |
) | |
self.set_attn_processor(processor) | |
def _set_gradient_checkpointing(self, module, value: bool = False) -> None: | |
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): | |
module.gradient_checkpointing = value | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu | |
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: | |
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. | |
The suffixes after the scaling factors represent the stage blocks where they are being applied. | |
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that | |
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. | |
Args: | |
s1 (`float`): | |
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to | |
mitigate the "oversmoothing effect" in the enhanced denoising process. | |
s2 (`float`): | |
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to | |
mitigate the "oversmoothing effect" in the enhanced denoising process. | |
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. | |
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. | |
""" | |
for i, upsample_block in enumerate(self.up_blocks): | |
setattr(upsample_block, "s1", s1) | |
setattr(upsample_block, "s2", s2) | |
setattr(upsample_block, "b1", b1) | |
setattr(upsample_block, "b2", b2) | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu | |
def disable_freeu(self) -> None: | |
"""Disables the FreeU mechanism.""" | |
freeu_keys = {"s1", "s2", "b1", "b2"} | |
for i, upsample_block in enumerate(self.up_blocks): | |
for k in freeu_keys: | |
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: | |
setattr(upsample_block, k, None) | |
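    # Usage sketch (illustrative comment): FreeU only sets the four scaling attributes on the
    # up blocks; FreeU-aware upsample blocks read them at forward time. The values below are
    # the ones commonly reported for Stable Diffusion v1.x; check the FreeU repository for
    # values suited to other backbones.
    #
    #     unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
    #     ...  # sample as usual
    #     unet.disable_freeu()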
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections | |
def fuse_qkv_projections(self): | |
""" | |
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, | |
key, value) are fused. For cross-attention modules, key and value projection matrices are fused. | |
<Tip warning={true}> | |
This API is 🧪 experimental. | |
</Tip> | |
""" | |
self.original_attn_processors = None | |
for _, attn_processor in self.attn_processors.items(): | |
if "Added" in str(attn_processor.__class__.__name__): | |
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") | |
self.original_attn_processors = self.attn_processors | |
for module in self.modules(): | |
if isinstance(module, Attention): | |
module.fuse_projections(fuse=True) | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections | |
def unfuse_qkv_projections(self): | |
"""Disables the fused QKV projection if enabled. | |
<Tip warning={true}> | |
This API is 🧪 experimental. | |
</Tip> | |
""" | |
if self.original_attn_processors is not None: | |
self.set_attn_processor(self.original_attn_processors) | |
def forward( | |
self, | |
sample: torch.FloatTensor, | |
timestep: Union[torch.Tensor, float, int], | |
encoder_hidden_states: torch.Tensor, | |
timestep_cond: Optional[torch.Tensor] = None, | |
attention_mask: Optional[torch.Tensor] = None, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
mid_block_additional_residual: Optional[torch.Tensor] = None, | |
return_dict: bool = True, | |
) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: | |
r""" | |
The [`UNetMotionModel`] forward method. | |
Args: | |
sample (`torch.FloatTensor`): | |
                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. | |
encoder_hidden_states (`torch.FloatTensor`): | |
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. | |
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): | |
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed | |
through the `self.time_embedding` layer to obtain the timestep embeddings. | |
attention_mask (`torch.Tensor`, *optional*, defaults to `None`): | |
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask | |
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large | |
negative values to the attention scores corresponding to "discard" tokens. | |
cross_attention_kwargs (`dict`, *optional*): | |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
`self.processor` in | |
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): | |
A tuple of tensors that if specified are added to the residuals of down unet blocks. | |
mid_block_additional_residual: (`torch.Tensor`, *optional*): | |
A tensor that if specified is added to the residual of the middle unet block. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain | |
tuple. | |
Returns: | |
[`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: | |
If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise | |
a `tuple` is returned where the first element is the sample tensor. | |
""" | |
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size | |
# on the fly if necessary. | |
default_overall_up_factor = 2**self.num_upsamplers | |
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` | |
forward_upsample_size = False | |
upsample_size = None | |
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): | |
logger.info("Forward upsample size to force interpolation output size.") | |
forward_upsample_size = True | |
# prepare attention_mask | |
if attention_mask is not None: | |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
attention_mask = attention_mask.unsqueeze(1) | |
# 1. time | |
timesteps = timestep | |
if not torch.is_tensor(timesteps): | |
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
# This would be a good case for the `match` statement (Python 3.10+) | |
is_mps = sample.device.type == "mps" | |
if isinstance(timestep, float): | |
dtype = torch.float32 if is_mps else torch.float64 | |
else: | |
dtype = torch.int32 if is_mps else torch.int64 | |
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
elif len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(sample.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
num_frames = sample.shape[2] | |
timesteps = timesteps.expand(sample.shape[0]) | |
t_emb = self.time_proj(timesteps) | |
# timesteps does not contain any weights and will always return f32 tensors | |
# but time_embedding might actually be running in fp16. so we need to cast here. | |
# there might be better ways to encapsulate this. | |
t_emb = t_emb.to(dtype=self.dtype) | |
emb = self.time_embedding(t_emb, timestep_cond) | |
emb = emb.repeat_interleave(repeats=num_frames, dim=0) | |
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) | |
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": | |
if "image_embeds" not in added_cond_kwargs: | |
raise ValueError( | |
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" | |
) | |
image_embeds = added_cond_kwargs.get("image_embeds") | |
image_embeds = self.encoder_hid_proj(image_embeds) | |
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds] | |
encoder_hidden_states = (encoder_hidden_states, image_embeds) | |
# 2. pre-process | |
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) | |
sample = self.conv_in(sample) | |
# 3. down | |
down_block_res_samples = (sample,) | |
for downsample_block in self.down_blocks: | |
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
sample, res_samples = downsample_block( | |
hidden_states=sample, | |
temb=emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
num_frames=num_frames, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
else: | |
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) | |
down_block_res_samples += res_samples | |
if down_block_additional_residuals is not None: | |
new_down_block_res_samples = () | |
for down_block_res_sample, down_block_additional_residual in zip( | |
down_block_res_samples, down_block_additional_residuals | |
): | |
down_block_res_sample = down_block_res_sample + down_block_additional_residual | |
new_down_block_res_samples += (down_block_res_sample,) | |
down_block_res_samples = new_down_block_res_samples | |
# 4. mid | |
if self.mid_block is not None: | |
# To support older versions of motion modules that don't have a mid_block | |
if hasattr(self.mid_block, "motion_modules"): | |
sample = self.mid_block( | |
sample, | |
emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
num_frames=num_frames, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
else: | |
sample = self.mid_block( | |
sample, | |
emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
if mid_block_additional_residual is not None: | |
sample = sample + mid_block_additional_residual | |
# 5. up | |
for i, upsample_block in enumerate(self.up_blocks): | |
is_final_block = i == len(self.up_blocks) - 1 | |
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
# if we have not reached the final block and need to forward the | |
# upsample size, we do it here | |
if not is_final_block and forward_upsample_size: | |
upsample_size = down_block_res_samples[-1].shape[2:] | |
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
encoder_hidden_states=encoder_hidden_states, | |
upsample_size=upsample_size, | |
attention_mask=attention_mask, | |
num_frames=num_frames, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
else: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
upsample_size=upsample_size, | |
num_frames=num_frames, | |
) | |
# 6. post-process | |
if self.conv_norm_out: | |
sample = self.conv_norm_out(sample) | |
sample = self.conv_act(sample) | |
sample = self.conv_out(sample) | |
        # reshape back to (batch, channels, num_frames, height, width)
sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) | |
if not return_dict: | |
return (sample,) | |
return UNet3DConditionOutput(sample=sample) | |
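# A minimal input-validation sketch (not part of the library). It encodes the shape contract of
# `UNetMotionModel.forward` documented above as explicit assertions; the helper name and the
# default of 3 upsamplers are assumptions, not library API.
def _check_motion_unet_inputs(
    sample: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    num_upsamplers: int = 3,
) -> None:
    # sample: (batch, channels, num_frames, height, width)
    assert sample.ndim == 5, f"expected a 5D sample, got {sample.ndim}D"
    batch, _, _, height, width = sample.shape
    # encoder hidden states: (batch, sequence_length, feature_dim)
    assert encoder_hidden_states.ndim == 3, "expected (batch, seq_len, feature_dim) text embeddings"
    assert encoder_hidden_states.shape[0] == batch, "batch sizes of sample and text embeddings must match"
    # spatial dims that are not multiples of 2**num_upsamplers trigger the `forward_upsample_size`
    # path above, which still works but forwards explicit upsample sizes.
    factor = 2**num_upsamplers
    if height % factor or width % factor:
        print(f"note: H/W are not multiples of {factor}; upsample sizes will be forwarded explicitly")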
import math | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ..activations import get_activation | |
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims | |
class DownResnetBlock1D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
num_layers: int = 1, | |
conv_shortcut: bool = False, | |
temb_channels: int = 32, | |
groups: int = 32, | |
groups_out: Optional[int] = None, | |
non_linearity: Optional[str] = None, | |
time_embedding_norm: str = "default", | |
output_scale_factor: float = 1.0, | |
add_downsample: bool = True, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
out_channels = in_channels if out_channels is None else out_channels | |
self.out_channels = out_channels | |
self.use_conv_shortcut = conv_shortcut | |
self.time_embedding_norm = time_embedding_norm | |
self.add_downsample = add_downsample | |
self.output_scale_factor = output_scale_factor | |
if groups_out is None: | |
groups_out = groups | |
# there will always be at least one resnet | |
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] | |
for _ in range(num_layers): | |
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) | |
self.resnets = nn.ModuleList(resnets) | |
if non_linearity is None: | |
self.nonlinearity = None | |
else: | |
self.nonlinearity = get_activation(non_linearity) | |
self.downsample = None | |
if add_downsample: | |
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
output_states = () | |
hidden_states = self.resnets[0](hidden_states, temb) | |
for resnet in self.resnets[1:]: | |
hidden_states = resnet(hidden_states, temb) | |
output_states += (hidden_states,) | |
if self.nonlinearity is not None: | |
hidden_states = self.nonlinearity(hidden_states) | |
if self.downsample is not None: | |
hidden_states = self.downsample(hidden_states) | |
return hidden_states, output_states | |
class UpResnetBlock1D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
num_layers: int = 1, | |
temb_channels: int = 32, | |
groups: int = 32, | |
groups_out: Optional[int] = None, | |
non_linearity: Optional[str] = None, | |
time_embedding_norm: str = "default", | |
output_scale_factor: float = 1.0, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
out_channels = in_channels if out_channels is None else out_channels | |
self.out_channels = out_channels | |
self.time_embedding_norm = time_embedding_norm | |
self.add_upsample = add_upsample | |
self.output_scale_factor = output_scale_factor | |
if groups_out is None: | |
groups_out = groups | |
# there will always be at least one resnet | |
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] | |
for _ in range(num_layers): | |
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) | |
self.resnets = nn.ModuleList(resnets) | |
if non_linearity is None: | |
self.nonlinearity = None | |
else: | |
self.nonlinearity = get_activation(non_linearity) | |
self.upsample = None | |
if add_upsample: | |
self.upsample = Upsample1D(out_channels, use_conv_transpose=True) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
if res_hidden_states_tuple is not None: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) | |
hidden_states = self.resnets[0](hidden_states, temb) | |
for resnet in self.resnets[1:]: | |
hidden_states = resnet(hidden_states, temb) | |
if self.nonlinearity is not None: | |
hidden_states = self.nonlinearity(hidden_states) | |
if self.upsample is not None: | |
hidden_states = self.upsample(hidden_states) | |
return hidden_states | |
class ValueFunctionMidBlock1D(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, embed_dim: int): | |
super().__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.embed_dim = embed_dim | |
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) | |
self.down1 = Downsample1D(out_channels // 2, use_conv=True) | |
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) | |
self.down2 = Downsample1D(out_channels // 4, use_conv=True) | |
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
x = self.res1(x, temb) | |
x = self.down1(x) | |
x = self.res2(x, temb) | |
x = self.down2(x) | |
return x | |
class MidResTemporalBlock1D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
embed_dim: int, | |
num_layers: int = 1, | |
add_downsample: bool = False, | |
add_upsample: bool = False, | |
non_linearity: Optional[str] = None, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.add_downsample = add_downsample | |
# there will always be at least one resnet | |
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] | |
for _ in range(num_layers): | |
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) | |
self.resnets = nn.ModuleList(resnets) | |
if non_linearity is None: | |
self.nonlinearity = None | |
else: | |
self.nonlinearity = get_activation(non_linearity) | |
self.upsample = None | |
if add_upsample: | |
self.upsample = Downsample1D(out_channels, use_conv=True) | |
self.downsample = None | |
if add_downsample: | |
self.downsample = Downsample1D(out_channels, use_conv=True) | |
if self.upsample and self.downsample: | |
raise ValueError("Block cannot downsample and upsample") | |
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_states = self.resnets[0](hidden_states, temb) | |
for resnet in self.resnets[1:]: | |
hidden_states = resnet(hidden_states, temb) | |
if self.upsample: | |
hidden_states = self.upsample(hidden_states) | |
if self.downsample: | |
            hidden_states = self.downsample(hidden_states)
return hidden_states | |
class OutConv1DBlock(nn.Module): | |
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str): | |
super().__init__() | |
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) | |
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) | |
self.final_conv1d_act = get_activation(act_fn) | |
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.final_conv1d_1(hidden_states) | |
hidden_states = rearrange_dims(hidden_states) | |
hidden_states = self.final_conv1d_gn(hidden_states) | |
hidden_states = rearrange_dims(hidden_states) | |
hidden_states = self.final_conv1d_act(hidden_states) | |
hidden_states = self.final_conv1d_2(hidden_states) | |
return hidden_states | |
class OutValueFunctionBlock(nn.Module): | |
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"): | |
super().__init__() | |
self.final_block = nn.ModuleList( | |
[ | |
nn.Linear(fc_dim + embed_dim, fc_dim // 2), | |
get_activation(act_fn), | |
nn.Linear(fc_dim // 2, 1), | |
] | |
) | |
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_states = hidden_states.view(hidden_states.shape[0], -1) | |
hidden_states = torch.cat((hidden_states, temb), dim=-1) | |
for layer in self.final_block: | |
hidden_states = layer(hidden_states) | |
return hidden_states | |
_kernels = { | |
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], | |
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], | |
"lanczos3": [ | |
0.003689131001010537, | |
0.015056144446134567, | |
-0.03399861603975296, | |
-0.066637322306633, | |
0.13550527393817902, | |
0.44638532400131226, | |
0.44638532400131226, | |
0.13550527393817902, | |
-0.066637322306633, | |
-0.03399861603975296, | |
0.015056144446134567, | |
0.003689131001010537, | |
], | |
} | |
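# Note: these are symmetric FIR low-pass kernels with unit DC gain (each sums to ~1.0).
# `Upsample1d` below multiplies the kernel by 2 so the stride-2 transposed convolution
# preserves signal amplitude after upsampling.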
class Downsample1d(nn.Module): | |
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): | |
super().__init__() | |
self.pad_mode = pad_mode | |
kernel_1d = torch.tensor(_kernels[kernel]) | |
self.pad = kernel_1d.shape[0] // 2 - 1 | |
self.register_buffer("kernel", kernel_1d) | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) | |
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) | |
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) | |
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) | |
weight[indices, indices] = kernel | |
return F.conv1d(hidden_states, weight, stride=2) | |
class Upsample1d(nn.Module): | |
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): | |
super().__init__() | |
self.pad_mode = pad_mode | |
kernel_1d = torch.tensor(_kernels[kernel]) * 2 | |
self.pad = kernel_1d.shape[0] // 2 - 1 | |
self.register_buffer("kernel", kernel_1d) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) | |
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) | |
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) | |
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) | |
weight[indices, indices] = kernel | |
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) | |
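# Shape sanity-check sketch (illustrative; the helper name is an assumption). `Downsample1d`
# halves the temporal length with a stride-2 FIR convolution and `Upsample1d` doubles it again
# with the matching transposed convolution.
def _example_fir_resampling(length: int = 64) -> None:
    x = torch.randn(2, 4, length)  # (batch, channels, length), length must be even
    down = Downsample1d(kernel="linear")
    up = Upsample1d(kernel="linear")
    y = down(x)
    assert y.shape == (2, 4, length // 2), y.shape
    z = up(y)
    assert z.shape == (2, 4, length), z.shape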
class SelfAttention1d(nn.Module): | |
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0): | |
super().__init__() | |
self.channels = in_channels | |
self.group_norm = nn.GroupNorm(1, num_channels=in_channels) | |
self.num_heads = n_head | |
self.query = nn.Linear(self.channels, self.channels) | |
self.key = nn.Linear(self.channels, self.channels) | |
self.value = nn.Linear(self.channels, self.channels) | |
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True) | |
self.dropout = nn.Dropout(dropout_rate, inplace=True) | |
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: | |
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) | |
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) | |
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) | |
return new_projection | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
residual = hidden_states | |
batch, channel_dim, seq = hidden_states.shape | |
hidden_states = self.group_norm(hidden_states) | |
hidden_states = hidden_states.transpose(1, 2) | |
query_proj = self.query(hidden_states) | |
key_proj = self.key(hidden_states) | |
value_proj = self.value(hidden_states) | |
query_states = self.transpose_for_scores(query_proj) | |
key_states = self.transpose_for_scores(key_proj) | |
value_states = self.transpose_for_scores(value_proj) | |
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) | |
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) | |
attention_probs = torch.softmax(attention_scores, dim=-1) | |
# compute attention output | |
hidden_states = torch.matmul(attention_probs, value_states) | |
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() | |
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) | |
hidden_states = hidden_states.view(new_hidden_states_shape) | |
# compute next hidden_states | |
hidden_states = self.proj_attn(hidden_states) | |
hidden_states = hidden_states.transpose(1, 2) | |
hidden_states = self.dropout(hidden_states) | |
output = hidden_states + residual | |
return output | |
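# Shape sanity-check sketch (illustrative; the helper name is an assumption). `SelfAttention1d`
# operates on (batch, channels, length) tensors and returns the same shape, adding the attention
# output to the residual input.
def _example_self_attention_1d() -> None:
    attn = SelfAttention1d(in_channels=32, n_head=4)  # channels must be divisible by n_head
    x = torch.randn(2, 32, 16)
    out = attn(x)
    assert out.shape == x.shape, out.shape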
class ResConvBlock(nn.Module): | |
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False): | |
super().__init__() | |
self.is_last = is_last | |
self.has_conv_skip = in_channels != out_channels | |
if self.has_conv_skip: | |
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) | |
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) | |
self.group_norm_1 = nn.GroupNorm(1, mid_channels) | |
self.gelu_1 = nn.GELU() | |
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) | |
if not self.is_last: | |
self.group_norm_2 = nn.GroupNorm(1, out_channels) | |
self.gelu_2 = nn.GELU() | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states | |
hidden_states = self.conv_1(hidden_states) | |
hidden_states = self.group_norm_1(hidden_states) | |
hidden_states = self.gelu_1(hidden_states) | |
hidden_states = self.conv_2(hidden_states) | |
if not self.is_last: | |
hidden_states = self.group_norm_2(hidden_states) | |
hidden_states = self.gelu_2(hidden_states) | |
output = hidden_states + residual | |
return output | |
class UNetMidBlock1D(nn.Module): | |
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): | |
super().__init__() | |
out_channels = in_channels if out_channels is None else out_channels | |
# there is always at least one resnet | |
self.down = Downsample1d("cubic") | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
attentions = [ | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(out_channels, out_channels // 32), | |
] | |
self.up = Upsample1d(kernel="cubic") | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.down(hidden_states) | |
for attn, resnet in zip(self.attentions, self.resnets): | |
hidden_states = resnet(hidden_states) | |
hidden_states = attn(hidden_states) | |
hidden_states = self.up(hidden_states) | |
return hidden_states | |
class AttnDownBlock1D(nn.Module): | |
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
self.down = Downsample1d("cubic") | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
attentions = [ | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(out_channels, out_channels // 32), | |
] | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.down(hidden_states) | |
for resnet, attn in zip(self.resnets, self.attentions): | |
hidden_states = resnet(hidden_states) | |
hidden_states = attn(hidden_states) | |
return hidden_states, (hidden_states,) | |
class DownBlock1D(nn.Module): | |
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
self.down = Downsample1d("cubic") | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.down(hidden_states) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
return hidden_states, (hidden_states,) | |
class DownBlock1DNoSkip(nn.Module): | |
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = torch.cat([hidden_states, temb], dim=1) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
return hidden_states, (hidden_states,) | |
class AttnUpBlock1D(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(2 * in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
attentions = [ | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(out_channels, out_channels // 32), | |
] | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
self.up = Upsample1d(kernel="cubic") | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
for resnet, attn in zip(self.resnets, self.attentions): | |
hidden_states = resnet(hidden_states) | |
hidden_states = attn(hidden_states) | |
hidden_states = self.up(hidden_states) | |
return hidden_states | |
class UpBlock1D(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = in_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(2 * in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
self.up = Upsample1d(kernel="cubic") | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
hidden_states = self.up(hidden_states) | |
return hidden_states | |
class UpBlock1DNoSkip(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = in_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(2 * in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
return hidden_states | |
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip] | |
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D] | |
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock] | |
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip] | |
def get_down_block( | |
down_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
add_downsample: bool, | |
) -> DownBlockType: | |
if down_block_type == "DownResnetBlock1D": | |
return DownResnetBlock1D( | |
in_channels=in_channels, | |
num_layers=num_layers, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
) | |
elif down_block_type == "DownBlock1D": | |
return DownBlock1D(out_channels=out_channels, in_channels=in_channels) | |
elif down_block_type == "AttnDownBlock1D": | |
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) | |
elif down_block_type == "DownBlock1DNoSkip": | |
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) | |
raise ValueError(f"{down_block_type} does not exist.") | |
def get_up_block( | |
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool | |
) -> UpBlockType: | |
if up_block_type == "UpResnetBlock1D": | |
return UpResnetBlock1D( | |
in_channels=in_channels, | |
num_layers=num_layers, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_upsample=add_upsample, | |
) | |
elif up_block_type == "UpBlock1D": | |
return UpBlock1D(in_channels=in_channels, out_channels=out_channels) | |
elif up_block_type == "AttnUpBlock1D": | |
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) | |
elif up_block_type == "UpBlock1DNoSkip": | |
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) | |
raise ValueError(f"{up_block_type} does not exist.") | |
def get_mid_block( | |
mid_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
mid_channels: int, | |
out_channels: int, | |
embed_dim: int, | |
add_downsample: bool, | |
) -> MidBlockType: | |
if mid_block_type == "MidResTemporalBlock1D": | |
return MidResTemporalBlock1D( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
embed_dim=embed_dim, | |
add_downsample=add_downsample, | |
) | |
elif mid_block_type == "ValueFunctionMidBlock1D": | |
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) | |
elif mid_block_type == "UNetMidBlock1D": | |
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) | |
raise ValueError(f"{mid_block_type} does not exist.") | |
def get_out_block( | |
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int | |
) -> Optional[OutBlockType]: | |
if out_block_type == "OutConv1DBlock": | |
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) | |
elif out_block_type == "ValueFunction": | |
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) | |
return None | |
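# A small end-to-end sketch of the block factories above (illustrative only; the helper name and
# the chosen channel sizes are assumptions). It wires one down block to the matching up block
# through the skip connection, mirroring what `UNet1DModel` does internally.
def _example_down_up_roundtrip() -> None:
    down = get_down_block(
        "DownBlock1D", num_layers=1, in_channels=16, out_channels=32, temb_channels=32, add_downsample=True
    )
    up = get_up_block(
        "UpBlock1D", num_layers=1, in_channels=32, out_channels=16, temb_channels=32, add_upsample=True
    )
    x = torch.randn(2, 16, 64)  # (batch, channels, length)
    hidden, skips = down(x)  # length halved, channels widened to 32
    assert hidden.shape == (2, 32, 32), hidden.shape
    out = up(hidden, res_hidden_states_tuple=skips)  # skip concat, then upsample back to length 64
    assert out.shape == (2, 16, 64), out.shape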
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from typing import Dict, Union | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torch.utils.checkpoint import checkpoint | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...loaders import PeftAdapterMixin | |
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock | |
from ..attention_processor import ( | |
ADDED_KV_ATTENTION_PROCESSORS, | |
CROSS_ATTENTION_PROCESSORS, | |
AttentionProcessor, | |
AttnAddedKVProcessor, | |
AttnProcessor, | |
) | |
from ..embeddings import TimestepEmbedding, get_timestep_embedding | |
from ..modeling_utils import ModelMixin | |
from ..normalization import GlobalResponseNorm, RMSNorm | |
from ..resnet import Downsample2D, Upsample2D | |
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): | |
_supports_gradient_checkpointing = True | |
@register_to_config | |
def __init__( | |
self, | |
# global config | |
hidden_size: int = 1024, | |
use_bias: bool = False, | |
hidden_dropout: float = 0.0, | |
# conditioning dimensions | |
cond_embed_dim: int = 768, | |
micro_cond_encode_dim: int = 256, | |
micro_cond_embed_dim: int = 1280, | |
encoder_hidden_size: int = 768, | |
# num tokens | |
vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded | |
codebook_size: int = 8192, | |
# `UVit2DConvEmbed` | |
in_channels: int = 768, | |
block_out_channels: int = 768, | |
num_res_blocks: int = 3, | |
downsample: bool = False, | |
upsample: bool = False, | |
block_num_heads: int = 12, | |
# `TransformerLayer` | |
num_hidden_layers: int = 22, | |
num_attention_heads: int = 16, | |
# `Attention` | |
attention_dropout: float = 0.0, | |
# `FeedForward` | |
intermediate_size: int = 2816, | |
# `Norm` | |
layer_norm_eps: float = 1e-6, | |
ln_elementwise_affine: bool = True, | |
sample_size: int = 64, | |
): | |
super().__init__() | |
self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias) | |
self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) | |
self.embed = UVit2DConvEmbed( | |
in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias | |
) | |
self.cond_embed = TimestepEmbedding( | |
micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias | |
) | |
self.down_block = UVitBlock( | |
block_out_channels, | |
num_res_blocks, | |
hidden_size, | |
hidden_dropout, | |
ln_elementwise_affine, | |
layer_norm_eps, | |
use_bias, | |
block_num_heads, | |
attention_dropout, | |
downsample, | |
False, | |
) | |
self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine) | |
self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias) | |
self.transformer_layers = nn.ModuleList( | |
[ | |
BasicTransformerBlock( | |
dim=hidden_size, | |
num_attention_heads=num_attention_heads, | |
attention_head_dim=hidden_size // num_attention_heads, | |
dropout=hidden_dropout, | |
cross_attention_dim=hidden_size, | |
attention_bias=use_bias, | |
norm_type="ada_norm_continuous", | |
ada_norm_continous_conditioning_embedding_dim=hidden_size, | |
norm_elementwise_affine=ln_elementwise_affine, | |
norm_eps=layer_norm_eps, | |
ada_norm_bias=use_bias, | |
ff_inner_dim=intermediate_size, | |
ff_bias=use_bias, | |
attention_out_bias=use_bias, | |
) | |
for _ in range(num_hidden_layers) | |
] | |
) | |
self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) | |
self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias) | |
self.up_block = UVitBlock( | |
block_out_channels, | |
num_res_blocks, | |
hidden_size, | |
hidden_dropout, | |
ln_elementwise_affine, | |
layer_norm_eps, | |
use_bias, | |
block_num_heads, | |
attention_dropout, | |
downsample=False, | |
upsample=upsample, | |
) | |
self.mlm_layer = ConvMlmLayer( | |
block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size | |
) | |
self.gradient_checkpointing = False | |
def _set_gradient_checkpointing(self, module, value: bool = False) -> None: | |
pass | |
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): | |
encoder_hidden_states = self.encoder_proj(encoder_hidden_states) | |
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) | |
micro_cond_embeds = get_timestep_embedding( | |
micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0 | |
) | |
micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1)) | |
pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1) | |
pooled_text_emb = pooled_text_emb.to(dtype=self.dtype) | |
pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype) | |
hidden_states = self.embed(input_ids) | |
hidden_states = self.down_block( | |
hidden_states, | |
pooled_text_emb=pooled_text_emb, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
batch_size, channels, height, width = hidden_states.shape | |
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels) | |
hidden_states = self.project_to_hidden_norm(hidden_states) | |
hidden_states = self.project_to_hidden(hidden_states) | |
for layer in self.transformer_layers: | |
if self.training and self.gradient_checkpointing: | |
def layer_(*args): | |
return checkpoint(layer, *args) | |
else: | |
layer_ = layer | |
hidden_states = layer_( | |
hidden_states, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
added_cond_kwargs={"pooled_text_emb": pooled_text_emb}, | |
) | |
hidden_states = self.project_from_hidden_norm(hidden_states) | |
hidden_states = self.project_from_hidden(hidden_states) | |
hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) | |
hidden_states = self.up_block( | |
hidden_states, | |
pooled_text_emb=pooled_text_emb, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
) | |
logits = self.mlm_layer(hidden_states) | |
return logits | |
@property | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors | |
def attn_processors(self) -> Dict[str, AttentionProcessor]: | |
r""" | |
Returns: | |
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
""" | |
# set recursively | |
processors = {} | |
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): | |
if hasattr(module, "get_processor"): | |
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) | |
for sub_name, child in module.named_children(): | |
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) | |
return processors | |
for name, module in self.named_children(): | |
fn_recursive_add_processors(name, module, processors) | |
return processors | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor | |
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): | |
r""" | |
Sets the attention processor to use to compute attention. | |
Parameters: | |
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): | |
The instantiated processor class or a dictionary of processor classes that will be set as the processor | |
for **all** `Attention` layers. | |
If `processor` is a dict, the key needs to define the path to the corresponding cross attention | |
processor. This is strongly recommended when setting trainable attention processors. | |
""" | |
count = len(self.attn_processors.keys()) | |
if isinstance(processor, dict) and len(processor) != count: | |
raise ValueError( | |
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" | |
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." | |
) | |
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
if hasattr(module, "set_processor"): | |
if not isinstance(processor, dict): | |
module.set_processor(processor) | |
else: | |
module.set_processor(processor.pop(f"{name}.processor")) | |
for sub_name, child in module.named_children(): | |
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) | |
for name, module in self.named_children(): | |
fn_recursive_attn_processor(name, module, processor) | |
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor | |
def set_default_attn_processor(self): | |
""" | |
Disables custom attention processors and sets the default attention implementation. | |
""" | |
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnAddedKVProcessor() | |
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): | |
processor = AttnProcessor() | |
else: | |
raise ValueError( | |
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" | |
) | |
self.set_attn_processor(processor) | |
class UVit2DConvEmbed(nn.Module): | |
def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias): | |
super().__init__() | |
self.embeddings = nn.Embedding(vocab_size, in_channels) | |
self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine) | |
self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias) | |
def forward(self, input_ids): | |
embeddings = self.embeddings(input_ids) | |
embeddings = self.layer_norm(embeddings) | |
embeddings = embeddings.permute(0, 3, 1, 2) | |
embeddings = self.conv(embeddings) | |
return embeddings | |
class UVitBlock(nn.Module): | |
def __init__( | |
self, | |
channels, | |
num_res_blocks: int, | |
hidden_size, | |
hidden_dropout, | |
ln_elementwise_affine, | |
layer_norm_eps, | |
use_bias, | |
block_num_heads, | |
attention_dropout, | |
downsample: bool, | |
upsample: bool, | |
): | |
super().__init__() | |
if downsample: | |
self.downsample = Downsample2D( | |
channels, | |
use_conv=True, | |
padding=0, | |
name="Conv2d_0", | |
kernel_size=2, | |
norm_type="rms_norm", | |
eps=layer_norm_eps, | |
elementwise_affine=ln_elementwise_affine, | |
bias=use_bias, | |
) | |
else: | |
self.downsample = None | |
self.res_blocks = nn.ModuleList( | |
[ | |
ConvNextBlock( | |
channels, | |
layer_norm_eps, | |
ln_elementwise_affine, | |
use_bias, | |
hidden_dropout, | |
hidden_size, | |
) | |
for i in range(num_res_blocks) | |
] | |
) | |
self.attention_blocks = nn.ModuleList( | |
[ | |
SkipFFTransformerBlock( | |
channels, | |
block_num_heads, | |
channels // block_num_heads, | |
hidden_size, | |
use_bias, | |
attention_dropout, | |
channels, | |
attention_bias=use_bias, | |
attention_out_bias=use_bias, | |
) | |
for _ in range(num_res_blocks) | |
] | |
) | |
if upsample: | |
self.upsample = Upsample2D( | |
channels, | |
use_conv_transpose=True, | |
kernel_size=2, | |
padding=0, | |
name="conv", | |
norm_type="rms_norm", | |
eps=layer_norm_eps, | |
elementwise_affine=ln_elementwise_affine, | |
bias=use_bias, | |
interpolate=False, | |
) | |
else: | |
self.upsample = None | |
def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs): | |
if self.downsample is not None: | |
x = self.downsample(x) | |
for res_block, attention_block in zip(self.res_blocks, self.attention_blocks): | |
x = res_block(x, pooled_text_emb) | |
batch_size, channels, height, width = x.shape | |
x = x.view(batch_size, channels, height * width).permute(0, 2, 1) | |
x = attention_block( | |
x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs | |
) | |
x = x.permute(0, 2, 1).view(batch_size, channels, height, width) | |
if self.upsample is not None: | |
x = self.upsample(x) | |
return x | |
class ConvNextBlock(nn.Module): | |
def __init__( | |
self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4 | |
): | |
super().__init__() | |
self.depthwise = nn.Conv2d( | |
channels, | |
channels, | |
kernel_size=3, | |
padding=1, | |
groups=channels, | |
bias=use_bias, | |
) | |
self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine) | |
self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias) | |
self.channelwise_act = nn.GELU() | |
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) | |
self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias) | |
self.channelwise_dropout = nn.Dropout(hidden_dropout) | |
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias) | |
def forward(self, x, cond_embeds): | |
x_res = x | |
x = self.depthwise(x) | |
x = x.permute(0, 2, 3, 1) | |
x = self.norm(x) | |
x = self.channelwise_linear_1(x) | |
x = self.channelwise_act(x) | |
x = self.channelwise_norm(x) | |
x = self.channelwise_linear_2(x) | |
x = self.channelwise_dropout(x) | |
x = x.permute(0, 3, 1, 2) | |
x = x + x_res | |
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1) | |
x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None] | |
return x | |
class ConvMlmLayer(nn.Module): | |
def __init__( | |
self, | |
block_out_channels: int, | |
in_channels: int, | |
use_bias: bool, | |
ln_elementwise_affine: bool, | |
layer_norm_eps: float, | |
codebook_size: int, | |
): | |
super().__init__() | |
self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias) | |
self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine) | |
self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias) | |
def forward(self, hidden_states): | |
hidden_states = self.conv1(hidden_states) | |
hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | |
logits = self.conv2(hidden_states) | |
return logits | |
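# Shape sanity-check sketch for the output head (illustrative; the helper name and sizes are
# assumptions). `ConvMlmLayer` maps the final feature map to per-pixel logits over the VQ codebook.
def _example_conv_mlm_layer() -> None:
    head = ConvMlmLayer(
        block_out_channels=16,
        in_channels=8,
        use_bias=False,
        ln_elementwise_affine=True,
        layer_norm_eps=1e-6,
        codebook_size=32,
    )
    features = torch.randn(2, 16, 4, 4)  # (batch, block_out_channels, height, width)
    logits = head(features)
    assert logits.shape == (2, 32, 4, 4), logits.shape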
from dataclasses import dataclass | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...utils import BaseOutput | |
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps | |
from ..modeling_utils import ModelMixin | |
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block | |
@dataclass | |
class UNet1DOutput(BaseOutput): | |
""" | |
The output of [`UNet1DModel`]. | |
Args: | |
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): | |
The hidden states output from the last layer of the model. | |
""" | |
sample: torch.FloatTensor | |
class UNet1DModel(ModelMixin, ConfigMixin): | |
r""" | |
    A 1D UNet model that takes a noisy sample and a timestep and returns a sample-shaped output.
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving). | |
Parameters: | |
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. | |
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. | |
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. | |
extra_in_channels (`int`, *optional*, defaults to 0): | |
Number of additional channels to be added to the input of the first down block. Useful for cases where the | |
input data has more channels than what the model was initially designed for. | |
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. | |
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. | |
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): | |
Whether to flip sin to cos for Fourier time embedding. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`): | |
Tuple of downsample block types. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`): | |
Tuple of upsample block types. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): | |
Tuple of block output channels. | |
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. | |
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. | |
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. | |
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. | |
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. | |
        downsample_each_block (`bool`, *optional*, defaults to `False`):
Experimental feature for using a UNet without upsampling. | |
""" | |
@register_to_config | |
def __init__( | |
self, | |
sample_size: int = 65536, | |
sample_rate: Optional[int] = None, | |
in_channels: int = 2, | |
out_channels: int = 2, | |
extra_in_channels: int = 0, | |
time_embedding_type: str = "fourier", | |
flip_sin_to_cos: bool = True, | |
use_timestep_embedding: bool = False, | |
freq_shift: float = 0.0, | |
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), | |
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), | |
mid_block_type: Tuple[str] = "UNetMidBlock1D", | |
out_block_type: str = None, | |
block_out_channels: Tuple[int] = (32, 32, 64), | |
act_fn: str = None, | |
norm_num_groups: int = 8, | |
layers_per_block: int = 1, | |
downsample_each_block: bool = False, | |
): | |
super().__init__() | |
self.sample_size = sample_size | |
# time | |
if time_embedding_type == "fourier": | |
self.time_proj = GaussianFourierProjection( | |
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
) | |
timestep_input_dim = 2 * block_out_channels[0] | |
elif time_embedding_type == "positional": | |
self.time_proj = Timesteps( | |
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift | |
) | |
timestep_input_dim = block_out_channels[0] | |
if use_timestep_embedding: | |
time_embed_dim = block_out_channels[0] * 4 | |
self.time_mlp = TimestepEmbedding( | |
in_channels=timestep_input_dim, | |
time_embed_dim=time_embed_dim, | |
act_fn=act_fn, | |
out_dim=block_out_channels[0], | |
) | |
self.down_blocks = nn.ModuleList([]) | |
self.mid_block = None | |
self.up_blocks = nn.ModuleList([]) | |
self.out_block = None | |
# down | |
output_channel = in_channels | |
for i, down_block_type in enumerate(down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
if i == 0: | |
input_channel += extra_in_channels | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = get_down_block( | |
down_block_type, | |
num_layers=layers_per_block, | |
in_channels=input_channel, | |
out_channels=output_channel, | |
temb_channels=block_out_channels[0], | |
add_downsample=not is_final_block or downsample_each_block, | |
) | |
self.down_blocks.append(down_block) | |
# mid | |
self.mid_block = get_mid_block( | |
mid_block_type, | |
in_channels=block_out_channels[-1], | |
mid_channels=block_out_channels[-1], | |
out_channels=block_out_channels[-1], | |
embed_dim=block_out_channels[0], | |
num_layers=layers_per_block, | |
add_downsample=downsample_each_block, | |
) | |
# up | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
output_channel = reversed_block_out_channels[0] | |
if out_block_type is None: | |
final_upsample_channels = out_channels | |
else: | |
final_upsample_channels = block_out_channels[0] | |
for i, up_block_type in enumerate(up_block_types): | |
prev_output_channel = output_channel | |
output_channel = ( | |
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels | |
) | |
is_final_block = i == len(block_out_channels) - 1 | |
up_block = get_up_block( | |
up_block_type, | |
num_layers=layers_per_block, | |
in_channels=prev_output_channel, | |
out_channels=output_channel, | |
temb_channels=block_out_channels[0], | |
add_upsample=not is_final_block, | |
) | |
self.up_blocks.append(up_block) | |
prev_output_channel = output_channel | |
# out | |
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) | |
self.out_block = get_out_block( | |
out_block_type=out_block_type, | |
num_groups_out=num_groups_out, | |
embed_dim=block_out_channels[0], | |
out_channels=out_channels, | |
act_fn=act_fn, | |
fc_dim=block_out_channels[-1] // 4, | |
) | |
def forward( | |
self, | |
sample: torch.FloatTensor, | |
timestep: Union[torch.Tensor, float, int], | |
return_dict: bool = True, | |
) -> Union[UNet1DOutput, Tuple]: | |
r""" | |
The [`UNet1DModel`] forward method. | |
Args: | |
sample (`torch.FloatTensor`): | |
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. | |
timestep (`torch.FloatTensor` or `float` or `int`): The current denoising timestep for the input (a single value or a per-sample batch). | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. | |
Returns: | |
[`~models.unet_1d.UNet1DOutput`] or `tuple`: | |
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is | |
returned where the first element is the sample tensor. | |
""" | |
# 1. time | |
timesteps = timestep | |
if not torch.is_tensor(timesteps): | |
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) | |
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(sample.device) | |
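# project the scalar timestep(s) to an embedding; if no timestep MLP is configured,
# the (batch, channels) embedding is tiled along the temporal axis below so the 1D
# blocks can concatenate it channel-wise with the sample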
timestep_embed = self.time_proj(timesteps) | |
if self.config.use_timestep_embedding: | |
timestep_embed = self.time_mlp(timestep_embed) | |
else: | |
timestep_embed = timestep_embed[..., None] | |
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) | |
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) | |
# 2. down | |
down_block_res_samples = () | |
for downsample_block in self.down_blocks: | |
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) | |
down_block_res_samples += res_samples | |
# 3. mid | |
if self.mid_block: | |
sample = self.mid_block(sample, timestep_embed) | |
# 4. up | |
for i, upsample_block in enumerate(self.up_blocks): | |
res_samples = down_block_res_samples[-1:] | |
down_block_res_samples = down_block_res_samples[:-1] | |
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) | |
# 5. post-process | |
if self.out_block: | |
sample = self.out_block(sample, timestep_embed) | |
if not return_dict: | |
return (sample,) | |
return UNet1DOutput(sample=sample) | |
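# --- Usage sketch (editorial addition, not part of the diffusers source above) ---
# A minimal, hedged example of driving UNet1DModel's forward() with the Dance
# Diffusion-style default blocks shown in __init__ above. extra_in_channels=16 is an
# assumption made here so the 16-dim Fourier time embedding concatenated by
# DownBlock1DNoSkip fits; shapes follow the (batch_size, num_channels, sample_size)
# contract from the forward docstring.
def _unet1d_usage_sketch():
    import torch
    from diffusers import UNet1DModel

    model = UNet1DModel(
        sample_size=65536,
        in_channels=2,
        out_channels=2,
        extra_in_channels=16,  # room for the concatenated Fourier time embedding
    )
    noisy = torch.randn(1, 2, 65536)                 # (batch, channels, sample_size)
    timestep = torch.tensor([10], dtype=torch.long)  # current denoising step index
    with torch.no_grad():
        out = model(noisy, timestep).sample          # UNet1DOutput.sample
    assert out.shape == noisy.shape                  # the 1D UNet preserves the waveform shape
    return out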
from typing import Dict, Optional, Tuple, Union | |
import flax | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
from flax.core.frozen_dict import FrozenDict | |
from ...configuration_utils import ConfigMixin, flax_register_to_config | |
from ...utils import BaseOutput | |
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps | |
from ..modeling_flax_utils import FlaxModelMixin | |
from .unet_2d_blocks_flax import ( | |
FlaxCrossAttnDownBlock2D, | |
FlaxCrossAttnUpBlock2D, | |
FlaxDownBlock2D, | |
FlaxUNetMidBlock2DCrossAttn, | |
FlaxUpBlock2D, | |
) | |
@flax.struct.dataclass | |
class FlaxUNet2DConditionOutput(BaseOutput): | |
""" | |
The output of [`FlaxUNet2DConditionModel`]. | |
Args: | |
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): | |
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. | |
""" | |
sample: jnp.ndarray | |
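# --- Usage sketch (editorial addition, not part of the diffusers source above) ---
# A hedged illustration of the call pattern that produces FlaxUNet2DConditionOutput:
# parameters come from init_weights() (defined further below) and apply() is called
# with a noisy latent, a timestep, and encoder hidden states. The tiny block
# configuration here is an assumption chosen to keep the trace cheap; it does not
# correspond to any real checkpoint.
def _flax_unet2d_condition_sketch():
    import jax
    import jax.numpy as jnp
    from diffusers import FlaxUNet2DConditionModel  # requires the flax extras

    model = FlaxUNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
        attention_head_dim=8,
    )
    params = model.init_weights(rng=jax.random.PRNGKey(0))
    sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)              # noisy latent, NCHW
    timesteps = jnp.array([10], dtype=jnp.int32)
    encoder_hidden_states = jnp.zeros((1, 77, 32), dtype=jnp.float32)  # text conditioning
    out = model.apply({"params": params}, sample, timesteps, encoder_hidden_states)
    return out.sample  # same (1, 4, 32, 32) shape as `sample`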
@flax_register_to_config | |
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): | |
r""" | |
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample | |
shaped output. | |
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods | |
implemented for all models (such as downloading or saving). | |
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) | |
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its | |
general usage and behavior. | |
Inherent JAX features such as the following are supported: | |
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) | |
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) | |
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) | |
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) | |
Parameters: | |
sample_size (`int`, *optional*): | |
The size of the input sample. | |
in_channels (`int`, *optional*, defaults to 4): | |
The number of channels in the input sample. | |
out_channels (`int`, *optional*, defaults to 4): | |
The number of channels in the output. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): | |
The tuple of downsample blocks to use. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): | |
The tuple of upsample blocks to use. | |
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): | |
The block type for the middle of the UNet; currently only `UNetMidBlock2DCrossAttn` is supported. If `None`, the mid block layer is skipped. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
The tuple of output channels for each block. | |
layers_per_block (`int`, *optional*, defaults to 2): | |
The number of layers per block. | |
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): | |
The dimension of the attention heads. | |
num_attention_heads (`int` or `Tuple[int]`, *optional*): | |
The number of attention heads. | |
cross_attention_dim (`int`, *optional*, defaults to 1280): | |
The dimension of the cross attention features. | |
dropout (`float`, *optional*, defaults to 0): | |
Dropout probability for down, up and bottleneck blocks. | |
flip_sin_to_cos (`bool`, *optional*, defaults to `True`): | |
Whether to flip the sin to cos in the time embedding. | |
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. | |
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): | |
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682). | |
split_head_dim (`bool`, *optional*, defaults to `False`): | |
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, | |
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. | |
""" | |
sample_size: int = 32 | |
in_channels: int = 4 | |
out_channels: int = 4 | |
down_block_types: Tuple[str, ...] = ( | |
"CrossAttnDownBlock2D", | |
"CrossAttnDownBlock2D", | |
"CrossAttnDownBlock2D", | |
"DownBlock2D", | |
) | |
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") | |
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn" | |
only_cross_attention: Union[bool, Tuple[bool]] = False | |
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280) | |
layers_per_block: int = 2 | |
attention_head_dim: Union[int, Tuple[int, ...]] = 8 | |
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None | |
cross_attention_dim: int = 1280 | |
dropout: float = 0.0 | |
use_linear_projection: bool = False | |
dtype: jnp.dtype = jnp.float32 | |
flip_sin_to_cos: bool = True | |
freq_shift: int = 0 | |
use_memory_efficient_attention: bool = False | |
split_head_dim: bool = False | |
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1 | |
addition_embed_type: Optional[str] = None | |
addition_time_embed_dim: Optional[int] = None | |
addition_embed_type_num_heads: int = 64 | |
projection_class_embeddings_input_dim: Optional[int] = None | |
def init_weights(self, rng: jax.Array) -> FrozenDict: | |
# init input tensors | |
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) | |
sample = jnp.zeros(sample_shape, dtype=jnp.float32) | |
timesteps = jnp.ones((1,), dtype=jnp.int32) | |
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) | |
params_rng, dropout_rng = jax.random.split(rng) | |
rngs = {"params": params_rng, "dropout": dropout_rng} | |
added_cond_kwargs = None | |
if self.addition_embed_type == "text_time": | |
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner | |
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim` | |
is_refiner = ( | |
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim | |
== self.config.projection_class_embeddings_input_dim | |
) | |
num_micro_conditions = 5 if is_refiner else 6 | |
text_embeds_dim = self.config.projection_class_embeddings_input_dim - ( | |
num_micro_conditions * self.config.addition_time_embed_dim | |
) | |
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim | |
time_ids_dims = time_ids_channels // self.addition_time_embed_dim | |
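            # worked example (SDXL-style base config, values stated here only as an illustration):
            # projection_class_embeddings_input_dim=2816, addition_time_embed_dim=256,
            # cross_attention_dim=2048 -> 5*256 + 2048 != 2816, so is_refiner is False,
            # num_micro_conditions=6, text_embeds_dim = 2816 - 6*256 = 1280, and
            # time_ids_dims = (2816 - 1280) // 256 = 6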
added_cond_kwargs = { | |
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32), | |
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32), | |
} | |
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] | |
def setup(self) -> None: | |
block_out_channels = self.block_out_channels | |
time_embed_dim = block_out_channels[0] * 4 | |
if self.num_attention_heads is not None: | |
raise ValueError( | |
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." | |
) | |
# If `num_attention_heads` is not defined (which is the case for most models) | |
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is. | |
# The reason for this behavior is to correct for incorrectly named variables that were introduced | |
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 | |
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking | |
# which is why we correct for the naming here. | |
num_attention_heads = self.num_attention_heads or self.attention_head_dim | |
# input | |
self.conv_in = nn.Conv( | |
block_out_channels[0], | |
kernel_size=(3, 3), | |
strides=(1, 1), | |
padding=((1, 1), (1, 1)), | |
dtype=self.dtype, | |
) | |
# time | |
self.time_proj = FlaxTimesteps( | |
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift | |
) | |
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) | |
only_cross_attention = self.only_cross_attention | |
if isinstance(only_cross_attention, bool): | |
only_cross_attention = (only_cross_attention,) * len(self.down_block_types) | |
if isinstance(num_attention_heads, int): | |
num_attention_heads = (num_attention_heads,) * len(self.down_block_types) | |
# transformer layers per block | |
transformer_layers_per_block = self.transformer_layers_per_block | |
if isinstance(transformer_layers_per_block, int): | |
transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types) | |
# addition embed types | |
if self.addition_embed_type is None: | |
self.add_embedding = None | |
elif self.addition_embed_type == "text_time": | |
if self.addition_time_embed_dim is None: | |
raise ValueError( | |
f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None" | |
) | |
self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift) | |
self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) | |
else: | |
raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.") | |
# down | |
down_blocks = [] | |
output_channel = block_out_channels[0] | |
for i, down_block_type in enumerate(self.down_block_types): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
if down_block_type == "CrossAttnDownBlock2D": | |
down_block = FlaxCrossAttnDownBlock2D( | |
in_channels=input_channel, | |
out_channels=output_channel, | |
dropout=self.dropout, | |
num_layers=self.layers_per_block, | |
transformer_layers_per_block=transformer_layers_per_block[i], | |
num_attention_heads=num_attention_heads[i], | |
add_downsample=not is_final_block, | |
use_linear_projection=self.use_linear_projection, | |
only_cross_attention=only_cross_attention[i], | |
use_memory_efficient_attention=self.use_memory_efficient_attention, | |
split_head_dim=self.split_head_dim, | |
dtype=self.dtype, | |
) | |
else: | |
down_block = FlaxDownBlock2D( | |
in_channels=input_channel, | |
out_channels=output_channel, | |
dropout=self.dropout, | |
num_layers=self.layers_per_block, | |
add_downsample=not is_final_block, | |
dtype=self.dtype, | |
) | |
down_blocks.append(down_block) | |
self.down_blocks = down_blocks | |
# mid | |
if self.config.mid_block_type == "UNetMidBlock2DCrossAttn": | |
self.mid_block = FlaxUNetMidBlock2DCrossAttn( | |
in_channels=block_out_channels[-1], | |
dropout=self.dropout, | |
num_attention_heads=num_attention_heads[-1], | |
transformer_layers_per_block=transformer_layers_per_block[-1], | |
use_linear_projection=self.use_linear_projection, | |
use_memory_efficient_attention=self.use_memory_efficient_attention, | |
split_head_dim=self.split_head_dim, | |
dtype=self.dtype, | |
) | |
elif self.config.mid_block_type is None: | |
self.mid_block = None | |
else: | |
raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}") | |
# up | |
up_blocks = [] | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
reversed_num_attention_heads = list(reversed(num_attention_heads)) | |
only_cross_attention = list(reversed(only_cross_attention)) | |
output_channel = reversed_block_out_channels[0] | |
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) | |
for i, up_block_type in enumerate(self.up_block_types): | |
prev_output_channel = output_channel | |
output_channel = reversed_block_out_channels[i] | |
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] | |
is_final_block = i == len(block_out_channels) - 1 | |
if up_block_type == "CrossAttnUpBlock2D": | |
up_block = FlaxCrossAttnUpBlock2D( | |
in_channels=input_channel, | |
out_channels=output_channel, | |
prev_output_channel=prev_output_channel, | |
num_layers=self.layers_per_block + 1, | |
transformer_layers_per_block=reversed_transformer_layers_per_block[i], |
# [file truncated here: the remainder of FlaxUNet2DConditionModel.setup() is not included in this gist view]