-
Notifications
You must be signed in to change notification settings - Fork 83
/
Copy pathpatch.py
279 lines (203 loc) · 10.5 KB
/
patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import torch
import math
from typing import Type, Dict, Any, Tuple, Callable
from . import merge
from .utils import isinstance_str, init_generator
def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable, ...]:
    """
    Build the ToMe merge / unmerge functions used by one transformer block.

    Args:
        x: Tokens entering the block. ``x.shape[0]`` is treated as the batch size and
           ``x.shape[1]`` as the number of tokens.
        tome_info: Shared ToMe state. ``tome_info["size"]`` holds the (h, w) recorded for the
           current forward pass and ``tome_info["args"]`` the patch settings. The random
           generator is (re)created lazily and cached in ``tome_info["args"]["generator"]``.

    Returns:
        ``(m_a, m_c, m_m, u_a, u_c, u_m)``: merge functions for self-attention,
        cross-attention and the MLP, followed by the matching unmerge functions.
        A component that is disabled in the args, or any layer whose downsampling exceeds
        ``max_downsample``, gets ``merge.do_nothing`` for both merge and unmerge.
    """
    full_h, full_w = tome_info["size"]
    num_tokens = x.shape[1]

    # Estimate how far this layer is downsampled relative to the recorded input size.
    downsample = int(math.ceil(math.sqrt((full_h * full_w) // num_tokens)))

    settings = tome_info["args"]

    if downsample > settings["max_downsample"]:
        merge_fn, unmerge_fn = (merge.do_nothing, merge.do_nothing)
    else:
        width = int(math.ceil(full_w / downsample))
        height = int(math.ceil(full_h / downsample))
        num_merged = int(num_tokens * settings["ratio"])

        # Lazily create the generator, or rebuild it when the input moved to another device.
        current_gen = settings["generator"]
        if current_gen is None:
            settings["generator"] = init_generator(x.device)
        elif current_gen.device != x.device:
            settings["generator"] = init_generator(x.device, fallback=current_gen)

        # With an odd batch, prompted and unprompted images cannot share the batch, which
        # causes artifacts with random dst selection, so randomness is forced off.
        if x.shape[0] % 2 == 1:
            allow_rand = False
        else:
            allow_rand = settings["use_rand"]

        merge_fn, unmerge_fn = merge.bipartite_soft_matching_random2d(
            x, width, height, settings["sx"], settings["sy"], num_merged,
            no_rand=not allow_rand, generator=settings["generator"])

    def _select(enabled):
        # Hand out the real pair only for the sub-layers ToMe is enabled on.
        if enabled:
            return merge_fn, unmerge_fn
        return merge.do_nothing, merge.do_nothing

    m_a, u_a = _select(settings["merge_attn"])
    m_c, u_c = _select(settings["merge_crossattn"])
    m_m, u_m = _select(settings["merge_mlp"])

    return m_a, m_c, m_m, u_a, u_c, u_m
def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Create, at runtime, a subclass of ``block_class`` whose ``_forward`` applies ToMe.

    Building the subclass on the fly means no specific model module has to be imported.
    The generated class is named ``ToMeBlock`` and keeps the original class in ``_parent``
    so that the patch can later be undone by restoring ``__class__``.
    """

    class ToMeBlock(block_class):
        # Original class, kept so the patch can be reverted.
        _parent = block_class

        def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
            m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(x, self._tome_info)

            # Self-attention: tokens are merged after norm1 and unmerged before the residual.
            attn_in = m_a(self.norm1(x))
            # attn1 only sees `context` when this block has self-attention disabled.
            attn1_context = context if self.disable_self_attn else None
            x = u_a(self.attn1(attn_in, context=attn1_context)) + x

            # Cross-attention against `context`.
            cross_in = m_c(self.norm2(x))
            x = u_c(self.attn2(cross_in, context=context)) + x

            # Feed-forward.
            mlp_in = m_m(self.norm3(x))
            x = u_m(self.ff(mlp_in)) + x

            return x

    return ToMeBlock
def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Make a patched class for a diffusers model.
    This patch applies ToMe to the forward function of the block.

    The returned class subclasses ``block_class``, is named ``ToMeBlock``, and stores
    ``block_class`` in ``_parent`` so the patch can be undone by restoring the original class.
    """
    class ToMeBlock(block_class):
        # Save for unpatching later
        _parent = block_class

        def forward(
            self,
            hidden_states,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            timestep=None,
            cross_attention_kwargs=None,
            class_labels=None,
        ) -> torch.Tensor:
            """
            Transformer-block forward with ToMe merge/unmerge around each sub-layer.

            For self-attention, cross-attention (only when ``self.attn2`` is not None) and the
            feed-forward, the normalized tokens are merged before the sub-layer and its output
            is unmerged before being added back to ``hidden_states`` as a residual.
            Reads the block flags ``use_ada_layer_norm``, ``use_ada_layer_norm_zero`` and
            ``only_cross_attention``.
            NOTE(review): this body appears to mirror an older diffusers
            ``BasicTransformerBlock.forward``; keep it in sync with the supported diffusers version.
            """
            # (1) ToMe
            m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(hidden_states, self._tome_info)

            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                # gate_msa / shift_mlp / scale_mlp / gate_mlp are only bound in this branch;
                # every later use below is guarded by the same use_ada_layer_norm_zero check.
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # (2) ToMe m_a
            norm_hidden_states = m_a(norm_hidden_states)

            # 1. Self-Attention
            # The same cross_attention_kwargs are forwarded to both attn1 and attn2.
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            attn_output = self.attn1(
                norm_hidden_states,
                # attn1 only attends to encoder states when the block is cross-attention-only.
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output

            # (3) ToMe u_a
            # Unmerge before the residual add so shapes match the un-merged hidden_states.
            hidden_states = u_a(attn_output) + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )
                # (4) ToMe m_c
                norm_hidden_states = m_c(norm_hidden_states)

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                # (5) ToMe u_c
                hidden_states = u_c(attn_output) + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            # (6) ToMe m_m
            norm_hidden_states = m_m(norm_hidden_states)

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            # (7) ToMe u_m
            hidden_states = u_m(ff_output) + hidden_states

            return hidden_states

    return ToMeBlock
def hook_tome_model(model: torch.nn.Module):
    """
    Register a forward pre-hook on ``model`` that records the spatial size of its input.

    On each forward call the hook stores dims 2 and 3 of the first positional argument as
    ``model._tome_info["size"]``. The hook handle is appended to ``model._tome_info["hooks"]``
    so it can be removed later (see ``remove_patch``).
    """
    handles = model._tome_info["hooks"]

    def record_input_size(module, inputs):
        first = inputs[0]
        module._tome_info["size"] = (first.shape[2], first.shape[3])
        # A pre-hook returning None leaves the forward inputs untouched.
        return None

    handles.append(model.register_forward_pre_hook(record_input_size))
def apply_patch(
        model: torch.nn.Module,
        ratio: float = 0.5,
        max_downsample: int = 1,
        sx: int = 2, sy: int = 2,
        use_rand: bool = True,
        merge_attn: bool = True,
        merge_crossattn: bool = False,
        merge_mlp: bool = False):
    """
    Patches a Stable Diffusion model with ToMe (token merging), in place.

    Accepted models:
     - a Stable Diffusion / Latent Diffusion object exposing ``.model.diffusion_model``;
     - a diffusers ``DiffusionPipeline`` (its ``.unet`` is patched) or a diffusers
       ``ModelMixin`` such as the UNet itself.
    Any existing ToMe patch is removed first, so calling this again replaces the settings.

    Important Args:
     - model: The top-level model to patch in place (see "Accepted models" above).
     - ratio: The ratio of tokens to merge. I.e., 0.4 would reduce the total number of tokens by 40%.
              The maximum value for this is 1-(1/(sx*sy)). By default, the max is 0.75 (I recommend <= 0.5 though).
              Higher values result in more speed-up, but with more visual quality loss.

    Args to tinker with if you want:
     - max_downsample [1, 2, 4, or 8]: Apply ToMe to layers with at most this amount of downsampling.
                                       E.g., 1 only applies to layers with no downsampling (4/15) while
                                       8 applies to all layers (15/15). I recommend a value of 1 or 2.
     - sx, sy: The stride for computing dst sets (see paper). A higher stride means you can merge more tokens,
               but the default of (2, 2) works well in most cases. Doesn't have to divide image size.
     - use_rand: Whether or not to allow random perturbations when computing dst sets (see paper). Usually
                 you'd want to leave this on, but if you're having weird artifacts try turning this off.
     - merge_attn: Whether or not to merge tokens for attention (recommended).
     - merge_crossattn: Whether or not to merge tokens for cross attention (not recommended).
     - merge_mlp: Whether or not to merge tokens for the mlp layers (very not recommended).

    Returns:
        The ``model`` argument, patched in place.

    Raises:
        RuntimeError: if ``model`` is not a diffusers object and has no ``.model.diffusion_model``.
    """
    # Make sure the module is not currently patched
    remove_patch(model)

    is_diffusers = isinstance_str(model, "DiffusionPipeline") or isinstance_str(model, "ModelMixin")

    if not is_diffusers:
        if not hasattr(model, "model") or not hasattr(model.model, "diffusion_model"):
            # Provided model not supported
            raise RuntimeError("Provided model was not a Stable Diffusion / Latent Diffusion model, as expected.")
        diffusion_model = model.model.diffusion_model
    else:
        # Supports "pipe.unet" and "unet"
        diffusion_model = model.unet if hasattr(model, "unet") else model

    # One state dict shared by the model (for the size hook) and every patched block.
    diffusion_model._tome_info = {
        "size": None,
        "hooks": [],
        "args": {
            "ratio": ratio,
            "max_downsample": max_downsample,
            "sx": sx, "sy": sy,
            "use_rand": use_rand,
            "generator": None,
            "merge_attn": merge_attn,
            "merge_crossattn": merge_crossattn,
            "merge_mlp": merge_mlp
        }
    }
    hook_tome_model(diffusion_model)

    # The block factory depends only on the model flavour, so pick it once.
    make_tome_block_fn = make_diffusers_tome_block if is_diffusers else make_tome_block

    for _, module in diffusion_model.named_modules():
        # If for some reason this has a different name, create an issue and I'll fix it
        if isinstance_str(module, "BasicTransformerBlock"):
            module.__class__ = make_tome_block_fn(module.__class__)
            module._tome_info = diffusion_model._tome_info

            # Something introduced in SD 2.0 (LDM only)
            if not hasattr(module, "disable_self_attn") and not is_diffusers:
                module.disable_self_attn = False

            # The patched diffusers forward reads use_ada_layer_norm / use_ada_layer_norm_zero.
            # Very old diffusers blocks lack these flags; newer releases replaced them with a
            # `norm_type` string. Derive the flags from `norm_type` when it exists instead of
            # forcing them off, so ada-norm blocks keep their normalization path. Plain
            # "layer_norm" blocks (Stable Diffusion UNets) still get both flags False.
            if is_diffusers and not hasattr(module, "use_ada_layer_norm_zero"):
                norm_type = getattr(module, "norm_type", None)
                module.use_ada_layer_norm = norm_type == "ada_norm"
                module.use_ada_layer_norm_zero = norm_type == "ada_norm_zero"

    return model
def remove_patch(model: torch.nn.Module):
    """
    Undo a ToMe patch on ``model``; a no-op when nothing is patched.

    When ``model`` has a ``unet`` attribute (a diffusers pipeline) the search starts there.
    For every submodule that carries ``_tome_info``, each hook handle in
    ``_tome_info["hooks"]`` is removed and the list is emptied. Every submodule whose class
    is named ``ToMeBlock`` is switched back to the class stored in its ``_parent``.

    Returns:
        The module that was searched (``model.unet`` when present, otherwise ``model``).
    """
    root = getattr(model, "unet", model)

    for submodule in root.modules():
        if hasattr(submodule, "_tome_info"):
            handles = submodule._tome_info["hooks"]
            for handle in handles:
                handle.remove()
            handles.clear()

        patched = submodule.__class__.__name__ == "ToMeBlock"
        if patched:
            submodule.__class__ = submodule._parent

    return root