Skip to content

Commit fe5cc79

Browse files
committed
fixes dream.py mps seed
1 parent 361cc42 commit fe5cc79

1 file changed

Lines changed: 18 additions & 6 deletions

File tree

ldm/simplet2i.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,13 +348,23 @@ def process_image(image,seed):
348348

349349
def get_noise():
350350
if init_img:
351-
return torch.randn_like(init_latent, device=self.device)
351+
if self.device.type == 'mps':
352+
return torch.randn_like(init_latent, device='cpu').to(self.device)
353+
else:
354+
return torch.randn_like(init_latent, device=self.device)
352355
else:
353-
return torch.randn([1,
354-
self.latent_channels,
355-
height // self.downsampling_factor,
356-
width // self.downsampling_factor],
357-
device=self.device)
356+
if self.device.type == 'mps':
357+
return torch.randn([1,
358+
self.latent_channels,
359+
height // self.downsampling_factor,
360+
width // self.downsampling_factor],
361+
device='cpu').to(self.device)
362+
else:
363+
return torch.randn([1,
364+
self.latent_channels,
365+
height // self.downsampling_factor,
366+
width // self.downsampling_factor],
367+
device=self.device)
358368

359369
initial_noise = None
360370
if variation_amount > 0 or len(with_variations) > 0:
@@ -383,6 +393,8 @@ def get_noise():
383393
x_T = initial_noise
384394
else:
385395
seed_everything(seed)
396+
if self.device.type == 'mps':
397+
x_T = get_noise()
386398
# make_image will do the equivalent of get_noise itself
387399
image = make_image(x_T)
388400
results.append([image, seed])

0 commit comments

Comments
 (0)