what's the expected output of concating two sharded tensors that share devices? even better is to add it as a failed test in test_multitensor.py. then we can discuss how to make it work as expected
I am making this writeup to show some of my findings and get some feedback.
What
The following is not supported by tinygrad:
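Something along these lines, a minimal sketch of the failing pattern (the original snippet isn't reproduced here, so the shapes and device names below are my assumptions):

```python
from tinygrad import Tensor

GPUS = ("GPU:0", "GPU:1")  # assumed device names; any two devices show the issue

# shard both tensors across the same devices along axis 0
a = Tensor.rand(2, 4).shard(GPUS, axis=0)
b = Tensor.rand(2, 4).shard(GPUS, axis=0)

# cat along the sharded axis -> hits the pad assert described below
c = a.cat(b, dim=0)
```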
This is due to `Tensor.cat` not working along the sharded axis (0). It tries to pad `a` and `b`, but fails an assert that blocks padding on the sharded axis.
Why
When performing inference with Stable Diffusion, you go through a process called Classifier-Free Guidance. The TLDR is that you need to run 2 separate samples through the model (for each denoising step): one conditioned on the text prompt and the other conditioned on an empty prompt.
To make this more efficient, most implementations will cat the two samples together and run them through the model in a single call, before chunking the output like so:
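Roughly like this (a sketch, not the actual SDXL code; `model`, `cond`, `uncond`, and `guidance_scale` are placeholder names):

```python
from tinygrad import Tensor

def cfg_step(model, x: Tensor, cond: Tensor, uncond: Tensor, guidance_scale: float) -> Tensor:
  # batch the conditioned and unconditioned samples into one model call
  latent = x.cat(x, dim=0)
  ctx = cond.cat(uncond, dim=0)
  out_cond, out_uncond = model(latent, ctx).chunk(2, dim=0)
  # classifier-free guidance: push the prediction away from the unconditioned output
  return out_uncond + guidance_scale * (out_cond - out_uncond)
```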
This is how I implemented the tinygrad version of SDXL, and it has worked fine until I went to run it on multiple GPUs. Since the `Tensor.cat` occurs along the batch dim, it runs into the issue above because that is also the shard axis.
I originally tried taking the lazy approach and just splitting it into 2 separate calls of the model. This does work, but George pushed me not to just hack things and to search for a proper solution.
How
Chenyu suggested an approach where the lazybuffers and devices are concatenated when one attempts to cat 2 `MultiLazyBuffer`s. This is doable, but ends up looking quite ugly with all of the checks needed.
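To make the intended semantics concrete, here is a toy model of the idea (this is not tinygrad code, just plain Python standing in for a `MultiLazyBuffer`'s shards/devices/axis):

```python
from dataclasses import dataclass

@dataclass
class ToyMulti:
  lbs: list      # per-device shards (stand-ins for the LazyBuffers)
  device: tuple  # device names, one per shard
  axis: int      # the shard axis

def toy_cat(a: ToyMulti, b: ToyMulti, dim: int) -> ToyMulti:
  # catting along the shard axis needs no data movement: the result is just
  # both shard lists (and both device tuples) joined end to end
  assert a.axis == b.axis == dim
  return ToyMulti(a.lbs + b.lbs, a.device + b.device, dim)

a = ToyMulti(["a0", "a1"], ("GPU:0", "GPU:1"), 0)
b = ToyMulti(["b0", "b1"], ("GPU:0", "GPU:1"), 0)
print(toy_cat(a, b, 0).device)  # ('GPU:0', 'GPU:1', 'GPU:0', 'GPU:1')
```

Note how the devices repeat in the result; that ties back to chenyu's question above about tensors that share devices, and is part of why the real version needs so many checks.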
This also only solves half of the problem, and the same needs to be implemented for `Tensor.shrink` for when one wants to chunk.
I took a look at using `MultiLazyBuffer.real`, but this just feels like a cheap hack that would cause more problems.
One could add a similar-looking `if` statement to the one above into `Tensor.shrink` that checks for a `MultiLazyBuffer` with devices along the bounds and returns a subset of the current devices and lazybuffers.
I also have not looked into the implications this might have on the backwards pass.
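The chunk/shrink half of the problem in the same toy model (again, not tinygrad code; it assumes the shrink bounds land exactly on shard boundaries, and reuses `ToyMulti`/`toy_cat` from the sketch above):

```python
def toy_shrink(t: ToyMulti, start_shard: int, end_shard: int) -> ToyMulti:
  # shrinking along the shard axis at shard boundaries just keeps a contiguous
  # subset of the shards and their matching devices
  return ToyMulti(t.lbs[start_shard:end_shard], t.device[start_shard:end_shard], t.axis)

both = toy_cat(a, b, 0)              # from the sketch above
out_cond = toy_shrink(both, 0, 2)    # first half -> back on ('GPU:0', 'GPU:1')
out_uncond = toy_shrink(both, 2, 4)  # second half
```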
Feedback