PyTorchのバックエンドとしてMPSを使い、Stable DiffusionがM1 Macで動いたと聞いた。MPSはMetal Performance Shaderのことらしい。
ほい? MetalならIntel MacのRadeonでも動くのでは?としてやってみた。
環境
- 2.3 GHz 8コアIntel Core i9
- AMD Radeon Pro 5500M 8 GB
- macOS Monterey 12.5.1
- Homebrewで入れたminiforge
追記4
GitHubに上げました。
普通に入れる
以下を参考にした:
https://rentry.org/SDInstallGuide
ダウンロードする。
% git clone https://github.com/CompVis/stable-diffusion.git
% cd stable-diffusion
environment.yamlを編集する。
CUDAを使わない:
- - cudatoolkit=11.3
+ # - cudatoolkit=11.3
MPS対応:
- - pytorch
+ - pytorch-nightly
- - pytorch=1.11.0
- - torchvision=0.12.0
+ - pytorch
+ - torchvision
仮想環境を用意する。
% conda env create -f environment.yaml
% conda activate ldm
% mkdir -p models/ldm/stable-diffusion-v1
https://huggingface.co/CompVis/stable-diffusion-v-1-4-originalでLog InやらAccess Repositoryし、https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/blob/main/sd-v1-4.ckptを(~/Downloadsに)ダウンロードする。
モデルを移動しリネームする。
% mv ~/Downloads/sd-v1-4.ckpt models/ldm/stable-diffusion-v1/model.ckpt
実行してみる。
% python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
動かない。
mpsが動くか確認
以下を参考にした:
mpsを認識するか確認する。
% python
>>> import torch
>>> torch.device('mps')
device(type='mps')
良さそう。
以下を実行してみる(同記事より引用)。
pytorch_m1_macbook.py
# -*- coding: utf-8 -*-
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms as tt
from torchvision.models import resnet18
import os
from argparse import ArgumentParser
import time
def main(device):
# ResNetのハイパーパラメータ
n_epoch = 5 # エポック数
batch_size = 512 # ミニバッチサイズ
momentum = 0.9 # SGDのmomentum
lr = 0.01 # 学習率
weight_decay = 0.00005 # weight decay
# 訓練データとテストデータを用意
mean = (0.491, 0.482, 0.446)
std = (0.247, 0.243, 0.261)
train_transform = tt.Compose([
tt.RandomHorizontalFlip(p=0.5),
tt.RandomCrop(size=32, padding=4, padding_mode='reflect'),
tt.ToTensor(),
tt.Normalize(mean=mean, std=std)
])
test_transform = tt.Compose([tt.ToTensor(), tt.Normalize(mean, std)])
root = os.path.dirname(os.path.abspath(__file__))
train_set = CIFAR10(root=root, train=True,
download=True, transform=train_transform)
train_loader = DataLoader(train_set, batch_size=batch_size,
shuffle=True, num_workers=8)
# ResNetの準備
resnet = resnet18()
resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)
# 訓練
criterion = CrossEntropyLoss()
optimizer = SGD(resnet.parameters(), lr=lr,
momentum=momentum, weight_decay=weight_decay)
train_start_time = time.time()
resnet.to(device)
resnet.train()
for epoch in range(1, n_epoch+1):
train_loss = 0.0
for inputs, labels in train_loader:
inputs = inputs.to(device)
optimizer.zero_grad()
outputs = resnet(inputs)
labels = labels.to(device)
loss = criterion(outputs, labels)
loss.backward()
train_loss += loss.item()
del loss # メモリ節約のため
optimizer.step()
print('Epoch {} / {}: time = {}[s], loss = {:.2f}'.format(
epoch, n_epoch, time.time() - train_start_time, train_loss))
print('Train time on {}: {:.2f}[s] (Train loss = {:.2f})'.format(
device, time.time() - train_start_time, train_loss))
# 評価
test_set = CIFAR10(root=root, train=False, download=True,
transform=test_transform)
test_loader = DataLoader(test_set, batch_size=batch_size,
shuffle=False, num_workers=8)
test_loss = 0.0
test_start_time = time.time()
resnet.eval()
for inputs, labels in test_loader:
inputs = inputs.to(device)
outputs = resnet(inputs)
labels = labels.to(device)
loss = criterion(outputs, labels)
test_loss += loss.item()
print('Test time on {}: {:.2f}[s](Test loss = {:.2f})'.format(
device, time.time() - test_start_time, test_loss))
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--device', type=str, default='mps',
choices=['cpu', 'mps'])
args = parser.parse_args()
device = torch.device(args.device)
main(device)
CPUの場合(抜粋):
% python pytorch_m1_macbook.py --device cpu
Epoch 1 / 5: time = 249.9817099571228[s], loss = 170.60
Epoch 2 / 5: time = 498.5888819694519[s], loss = 137.21
Epoch 3 / 5: time = 762.4725549221039[s], loss = 122.71
Epoch 4 / 5: time = 1022.5609741210938[s], loss = 112.18
Epoch 5 / 5: time = 1274.3697321414948[s], loss = 103.73
Train time on cpu: 1274.37[s] (Train loss = 103.73)
Test time on cpu: 58.76[s](Test loss = 20.09)
GPUの場合(抜粋):
% python pytorch_m1_macbook.py --device mps
Epoch 1 / 5: time = 131.3166902065277[s], loss = 170.33
Epoch 2 / 5: time = 246.86656522750854[s], loss = 137.14
Epoch 3 / 5: time = 362.39308524131775[s], loss = 122.12
Epoch 4 / 5: time = 478.34768986701965[s], loss = 113.14
Epoch 5 / 5: time = 594.5503239631653[s], loss = 104.61
Train time on mps: 594.55[s] (Train loss = 104.61)
est time on mps: 59.96[s](Test loss = 20.42)
2倍程度に速くなった。
Stable Diffusionのコードを修正
以下を参考にした:
4つのファイルを編集。
- scripts/txt2img.py
- ldm/models/diffusion/plms.py
- configs/stable-diffusion/v1-inference.yamlconfigs/stable-diffusion/v1-inference.yaml
- /usr/local/Caskroom/miniforge/base/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py
- ldm/modules/attention.py
scripts/txt2img.py
print("unexpected keys:")
print(u)
- model.cuda()
+ # model.cuda()
+ model.to("mps")
model.eval()
return model
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = model.to(device)
if opt.plms:
precision_scope = autocast if opt.precision=="autocast" else nullcontext
with torch.no_grad():
- with precision_scope("cuda"):
+ # with precision_scope("cuda"):
+ with nullcontext("mps"):
with model.ema_scope():
tic = time.time()
all_samples = list()
ldm/models/diffusion/plms.py
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
+ if attr.device != torch.device("mps"):
+ attr = attr.to(torch.float32).to(torch.device("mps")).contiguous()
setattr(self, name, attr)
configs/stable-diffusion/v1-inference.yamlconfigs/stable-diffusion/v1-inference.yaml(編集必要でした)
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params: # edit
+ device: mps # edit
/usr/local/Caskroom/miniforge/base/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py
return handle_torch_function(
layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
)
- return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
+ return torch.layer_norm(input.contiguous(), normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) # edit
ldm/modules/attention.py
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
+ x = x.contiguous() # edit
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
改めて実行する。
% python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
動いた!
(下段中央の画像はNSFW filterに引っかかったみたい。)
問題点
画像6枚で1時間以上かかりました。
ただ、GPUが10-20%程度しか使われてませんでした(以下記事追記に類似)。
フルで使われるなど、更なる高速化ができないか探してみます。
追記
よくわかりませんが、画像2枚(--n_samples 1 --n_rows 1
)の方がGPU使ってくれて(70-80%くらい)、6分くらいで終わります。
あと、--n_rows 1
としているのにグリッドが2行になるのもよくわからないです。
追記2
--n_iter
:グリッドの行数(デフォルト2)--n_rows
:グリッド1行に何枚か(デフォルト=n_samples
)
--n_samples 1 --n_iter 1
とすべきだった。
画像1枚3分くらいでできた。
追記3
torch/nn/functional.pyではなくldm/modules/attention.pyを編集するように変更しました。これでライブラリに手を加えなくて済みます。
目下の興味は--n_samples 2
のときにAppleInternal/(中略)/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:705: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: product of dimension sizes > 2**31'
というエラーが出ること。MPS側の問題でしょうか。--n_samples 3
でも動くのになぜ。