Last active
October 30, 2022 01:52
-
-
Save drscotthawley/81865a5c5e729b769486efb9c3f2249d to your computer and use it in GitHub Desktop.
Wrapper to give einops.rearrange an "inverse"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from einops import rearrange as _rearrange | |
class RearrangeWrapper:
    """Wrapper to endow ``einops.rearrange`` with an 'inverse' operation.

    Calling the instance behaves exactly like ``einops.rearrange`` but also
    records the input shape and pattern of that (most recent) call, so that
    :meth:`inverse` can undo it later. Only the most recent forward call is
    remembered.
    """

    def __init__(self):
        # Nothing recorded yet; inverse() raises until a forward call happens.
        self.shape, self.s = None, None

    def __call__(self, x, s: str, **kwargs):
        """Forward call: record ``x.shape`` and ``s``, then rearrange ``x``.

        Kept lightweight so existing ``rearrange(x, pattern, **kwargs)`` usage
        is unchanged; ``kwargs`` are forwarded unchanged to ``einops.rearrange``.
        """
        self.shape, self.s = x.shape, s
        return _rearrange(x, s, **kwargs)

    @staticmethod
    def _axis_sizes(lhs, shape):
        """Map each plain axis name on the pattern's input side to its size in ``shape``.

        Args:
            lhs: input side of an einops pattern (the text before ``'->'``).
            shape: shape recorded from the forward call's input.

        Returns:
            dict of axis name -> size, for plain identifier axes only.

        Raises:
            ValueError: if ``lhs`` cannot describe a tensor of ``shape``'s rank.
        """
        import re

        # A parenthesised group such as '(h w)' occupies a single input dim, so
        # keep it as one token; any run of whitespace separates tokens.
        tokens = re.findall(r'\([^)]*\)|[^\s()]+', lhs)
        if '...' in tokens:
            # Axes before the ellipsis align with the leading dims, axes after
            # it with the trailing dims; the ellipsis absorbs the rest.
            cut = tokens.index('...')
            head, tail = tokens[:cut], tokens[cut + 1:]
            if len(head) + len(tail) > len(shape):
                raise ValueError(
                    f"pattern {lhs.strip()!r} needs at least {len(head) + len(tail)} "
                    f"dims but the recorded shape has {len(shape)}")
            pairs = list(zip(head, shape[:len(head)]))
            pairs += zip(tail, shape[len(shape) - len(tail):])
        else:
            if len(tokens) != len(shape):
                raise ValueError(
                    f"pattern {lhs.strip()!r} describes {len(tokens)} dims "
                    f"but the recorded shape has {len(shape)}")
            pairs = zip(tokens, shape)
        # Only simple identifiers can be passed as einops axis lengths; groups
        # like '(h w)' and anonymous axes like '1' or '()' are skipped.
        return {name: size for name, size in pairs if name.isidentifier()}

    def inverse(self,
                y,  # tensor, e.g. the result of the forward call
                infer_dim='',  # axis name (from self.s), or iterable of names, whose size einops should infer
                ):
        """Undo the most recent forward call.

        Swaps the two sides of the recorded pattern and passes the recorded
        sizes of the input axes so einops can decompose any grouped axes.

        Raises:
            RuntimeError: if no forward call has been recorded yet.
            ValueError: if the recorded pattern's input side does not match
                the recorded shape's rank.
        """
        # Explicit raises instead of asserts: asserts vanish under `python -O`.
        if self.shape is None or self.s is None:
            raise RuntimeError("inverse called before forward method")
        lhs, _, rhs = self.s.partition('->')  # 'before' and 'after' of the forward transform
        axes_info = self._axis_sizes(lhs, self.shape)
        names = [infer_dim] if isinstance(infer_dim, str) else infer_dim
        for name in names:
            axes_info.pop(name, None)  # let einops infer these sizes from y
        return _rearrange(y, rhs + ' -> ' + lhs, **axes_info)


# only have to instantiate this once for the rest of the code, even with different parameters/dims
rearrange = RearrangeWrapper()
One limitation of the above implementation is that it only remembers the most recent forward call, so it can't "invert" a call further "into the past":
# Illustration of the limitation: two forward calls, then an attempt to invert the FIRST one.
# NOTE(review): assumes `torch` has already been imported and `rearrange` is the RearrangeWrapper instance defined above.
x = torch.rand((5,2,4))
x2 = torch.rand((2,6,5,4))
y = rearrange(x, 'b c n -> n (c b)')  # records x.shape and this pattern
y2 = rearrange(x2, 'b c w h -> h (c b) w')  # overwrites the record with x2.shape and this pattern
yinv = rearrange.inverse(y) # this won't work: inverse() only remembers the latest (x2) call, so it applies the inverse of the second pattern to y
This could perhaps be fixed by having the forward call optionally also return some "archival" info that could later be passed to inverse. 🤷 For example:
class RearrangeWrapperArchive:
    """Wrapper to endow ``einops.rearrange`` with an 'inverse' operation.

    Calling the instance behaves like ``einops.rearrange`` and records the
    input shape and pattern of that call. With ``archive=True`` the call also
    returns that ``(shape, pattern)`` record, which can later be passed to
    :meth:`inverse` to undo an *older* call, not just the most recent one.
    """

    def __init__(self):
        # Nothing recorded yet; inverse() without an archive raises until a forward call happens.
        self.shape, self.s = None, None

    def __call__(self, x, s: str, archive=False, **kwargs):
        """Forward call: rearrange ``x`` with pattern ``s``.

        Returns the rearranged tensor, or ``(tensor, (x.shape, s))`` when
        ``archive`` is true. ``kwargs`` are forwarded unchanged to
        ``einops.rearrange``.
        """
        self.shape, self.s = x.shape, s
        out = _rearrange(x, s, **kwargs)
        # Parenthesise explicitly: without the parentheses the conditional binds
        # only to the second tuple element, so archive=False would return a
        # 2-tuple and run the rearrangement twice.
        return (out, (self.shape, self.s)) if archive else out

    @staticmethod
    def _axis_sizes(lhs, shape):
        """Map each plain axis name on the pattern's input side to its size in ``shape``.

        Args:
            lhs: input side of an einops pattern (the text before ``'->'``).
            shape: shape recorded from the forward call's input.

        Returns:
            dict of axis name -> size, for plain identifier axes only.

        Raises:
            ValueError: if ``lhs`` cannot describe a tensor of ``shape``'s rank.
        """
        import re

        # A parenthesised group such as '(h w)' occupies a single input dim, so
        # keep it as one token; any run of whitespace separates tokens.
        tokens = re.findall(r'\([^)]*\)|[^\s()]+', lhs)
        if '...' in tokens:
            # Axes before the ellipsis align with the leading dims, axes after
            # it with the trailing dims; the ellipsis absorbs the rest.
            cut = tokens.index('...')
            head, tail = tokens[:cut], tokens[cut + 1:]
            if len(head) + len(tail) > len(shape):
                raise ValueError(
                    f"pattern {lhs.strip()!r} needs at least {len(head) + len(tail)} "
                    f"dims but the recorded shape has {len(shape)}")
            pairs = list(zip(head, shape[:len(head)]))
            pairs += zip(tail, shape[len(shape) - len(tail):])
        else:
            if len(tokens) != len(shape):
                raise ValueError(
                    f"pattern {lhs.strip()!r} describes {len(tokens)} dims "
                    f"but the recorded shape has {len(shape)}")
            pairs = zip(tokens, shape)
        # Only simple identifiers can be passed as einops axis lengths; groups
        # like '(h w)' and anonymous axes like '1' or '()' are skipped.
        return {name: size for name, size in pairs if name.isidentifier()}

    def inverse(self,
                y,  # tensor, e.g. the result of a forward call
                infer_dim='',  # axis name (or iterable of names) whose size einops should infer
                archive=None,  # (shape, pattern) tuple returned by a forward call with archive=True
                ):
        """Undo a previous forward call.

        Uses ``archive`` if given, otherwise the most recent forward call. The
        recorded state is only read, never overwritten, so passing an archive
        does not change what a later ``inverse(y)`` without one will undo.

        Raises:
            RuntimeError: if no archive is given and no forward call has been
                recorded yet.
            ValueError: if the pattern's input side does not match the shape's
                rank.
        """
        shape, s = (self.shape, self.s) if archive is None else archive
        # Explicit raises instead of asserts: asserts vanish under `python -O`.
        if shape is None or s is None:
            raise RuntimeError("inverse called before forward method")
        lhs, _, rhs = s.partition('->')  # 'before' and 'after' of the forward transform
        axes_info = self._axis_sizes(lhs, shape)
        names = [infer_dim] if isinstance(infer_dim, str) else infer_dim
        for name in names:
            axes_info.pop(name, None)  # let einops infer these sizes from y
        return _rearrange(y, rhs + ' -> ' + lhs, **axes_info)


# only have to instantiate this once for the rest of the code, even with different parameters/dims
rearrange = RearrangeWrapperArchive()
import torch  # the demo uses torch tensors, but torch was never imported

# Two forward calls, each archiving its (shape, pattern) so that either one
# can be undone later, independently of which call came last.
x = torch.rand((5,2,4))
y, a1 = rearrange(x, 'b c n -> n (c b)', archive=True)
x2 = torch.rand((2,6,5,4))
y2, a2 = rearrange(x2, 'b c w h -> h (c b) w', archive=True)
# Each inverse should recover the shape of the corresponding input.
print(x.shape, y.shape, rearrange.inverse(y, archive=a1).shape)
print(x2.shape, y2.shape, rearrange.inverse(y2, archive=a2).shape)
Output:
torch.Size([5, 2, 4]) torch.Size([4, 10]) torch.Size([5, 2, 4])
torch.Size([2, 6, 5, 4]) torch.Size([4, 12, 5]) torch.Size([2, 6, 5, 4])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
My response to an idea by Francois Fleuret:
Colab notebook: https://colab.research.google.com/drive/1chDM3IPyyn_KNV_rF-XqS666FT58oXBe?usp=sharing
Usage examples:
Example 1:
Output:
Example 2:
Output:
Example 3:
Output:
Example 4:
Output:
Example 5:
Output: