File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -348,13 +348,23 @@ def process_image(image,seed):
348348
349349 def get_noise ():
350350 if init_img :
351- return torch .randn_like (init_latent , device = self .device )
351+ if self .device .type == 'mps' :
352+ return torch .randn_like (init_latent , device = 'cpu' ).to (self .device )
353+ else :
354+ return torch .randn_like (init_latent , device = self .device )
352355 else :
353- return torch .randn ([1 ,
354- self .latent_channels ,
355- height // self .downsampling_factor ,
356- width // self .downsampling_factor ],
357- device = self .device )
356+ if self .device .type == 'mps' :
357+ return torch .randn ([1 ,
358+ self .latent_channels ,
359+ height // self .downsampling_factor ,
360+ width // self .downsampling_factor ],
361+ device = 'cpu' ).to (self .device )
362+ else :
363+ return torch .randn ([1 ,
364+ self .latent_channels ,
365+ height // self .downsampling_factor ,
366+ width // self .downsampling_factor ],
367+ device = self .device )
358368
359369 initial_noise = None
360370 if variation_amount > 0 or len (with_variations ) > 0 :
@@ -383,6 +393,8 @@ def get_noise():
383393 x_T = initial_noise
384394 else :
385395 seed_everything (seed )
396+ if self .device .type == 'mps' :
397+ x_T = get_noise ()
386398 # make_image will do the equivalent of get_noise itself
387399 image = make_image (x_T )
388400 results .append ([image , seed ])
You can’t perform that action at this time.
0 commit comments