Skip to content

Instantly share code, notes, and snippets.

@drscotthawley
Last active October 30, 2022 01:52
Show Gist options
  • Save drscotthawley/81865a5c5e729b769486efb9c3f2249d to your computer and use it in GitHub Desktop.
Save drscotthawley/81865a5c5e729b769486efb9c3f2249d to your computer and use it in GitHub Desktop.
Wrapper to give einops.rearrange an "inverse"
from einops import rearrange as _rearrange
class RearrangeWrapper:
    "wrapper to endow einops.rearrange with an 'inverse' operation"

    def __init__(self):
        # Input shape and pattern of the most recent forward call. Both start as
        # None so inverse() can detect being called before any forward call.
        self.shape, self.s = None, None

    def __call__(self, x, s: str, **kwargs):
        """Rearrange ``x`` with pattern ``s`` and remember how to undo it.

        Drop-in for ``einops.rearrange(x, s, **kwargs)``. The only extra work
        (kept lightweight on purpose) is storing ``x.shape`` and ``s``. These
        overwrite whatever a previous call stored.
        """
        self.shape, self.s = x.shape, s
        return _rearrange(x, s, **kwargs)

    def inverse(self, y, infer_dim=''):
        """Undo the most recent forward call.

        Args:
            y: array/tensor to map back, e.g. the result of the forward call.
            infer_dim: optional axis name from the left-hand side of the stored
                pattern whose length should NOT be pinned to its recorded
                value. einops then infers it from ``y``, which is useful when
                ``y`` differs in size from the forward output along that axis.
                Names not present in the pattern are ignored.

        Returns:
            ``y`` rearranged with the reversed pattern (``rhs -> lhs``).

        Raises:
            AssertionError: if no forward call has been made yet, or if the
                left-hand side of the stored pattern does not name exactly one
                axis per dimension of the recorded input shape.

        Note:
            The left-hand side of the forward pattern must list single,
            whitespace-separated axis names. Parenthesised groups or ``...``
            there are not supported.
        """
        # Raised explicitly (not via ``assert``) so the checks survive
        # ``python -O``. AssertionError is kept because callers may rely on it.
        if self.shape is None or self.s is None:
            raise AssertionError("inverse called before forward method")
        lhs, _, rhs = self.s.partition('->')
        # split() with no argument tolerates repeated/leading/trailing spaces,
        # which split(' ') turned into spurious empty axis names.
        axes = lhs.split()
        if len(axes) != len(self.shape):
            raise AssertionError(
                f"pattern {self.s!r} names {len(axes)} input axes but the "
                f"recorded input shape has {len(self.shape)} dimensions"
            )
        # Pin every input axis length so einops can split grouped output axes.
        axes_info = dict(zip(axes, self.shape))
        axes_info.pop(infer_dim, None)  # let einops infer this axis from y
        return _rearrange(y, rhs + ' -> ' + lhs, **axes_info)
# Module-level instance named like einops' function so it can be used as a
# drop-in replacement for it (einops' own function is imported as _rearrange).
# One instance serves every pattern and shape because each forward call
# overwrites the stored state. This also means inverse() can only undo the most
# recent forward call.
rearrange = RearrangeWrapper()
@drscotthawley
Copy link
Author

drscotthawley commented Oct 29, 2022

My response to an idea by François Fleuret:

Colab notebook: https://colab.research.google.com/drive/1chDM3IPyyn_KNV_rF-XqS666FT58oXBe?usp=sharing

Usage examples:

Example 1:

# Example 1: flatten all axes after the first one, then undo it.
import torch

x = torch.ones(3,4,5,6)
print(x.shape)
y = rearrange(x, 'n c h w -> n (c h w)')
# No extra arguments needed: inverse() reuses the pattern and shape recorded
# by the forward call above.
z = rearrange.inverse(y)
print(y.shape)
print(z.shape)
assert x.shape == z.shape

Output:

torch.Size([3, 4, 5, 6])
torch.Size([3, 120])
torch.Size([3, 4, 5, 6])

Example 2:

# Example 2: flatten to 1-D, then invert while letting einops infer one axis.
x = torch.ones(4,8,17)
y = rearrange(x, 'c h w -> (c h w)')
# 'h' is dropped from the pinned sizes, so einops infers it from y's length
# (544 = 4*h*17, giving h = 8). The recorded size would give the same answer,
# so infer_dim is unnecessary here but it doesn't break anything.
z = rearrange.inverse(y, infer_dim='h')
print(y.shape)
print(z.shape)
assert x.shape == z.shape

Output:

torch.Size([544])
torch.Size([4, 8, 17])

Example 3:

# Example 3: use infer_dim when the tensor being inverted has a different size
# along one axis than the forward input did.
x = torch.ones(3,4,5,6)
y = rearrange(x, 'n c h w -> n (c h w)')
print("x.shape =",x.shape,"\ny.shape =",y.shape)
y2 = torch.ones(y.shape[0], 2*5*6)          # test infer_dim: change number of channels from 4 to 2
# With 'c' unpinned, einops solves 60 = c*5*6 for c = 2; n, h and w keep their
# recorded sizes (3, 5, 6).
z = rearrange.inverse(y2, infer_dim='c')
print("y2.shape =",y2.shape,"\nz.shape =",z.shape)
assert z.shape == torch.Size([3, 2, 5, 6])

Output:

x.shape = torch.Size([3, 4, 5, 6])
y.shape = torch.Size([3, 120])
y2.shape = torch.Size([3, 60])
z.shape = torch.Size([3, 2, 5, 6])

Example 4:

# Example 4: the wrapper also works on NumPy arrays, and the round trip restores
# the original values, not just the shape.
import numpy as np

q = np.random.rand(8,2,16)
qprime = rearrange(q, 'b c n -> (b n) c ')
print("q.shape =",q.shape,"\nqprime.shape =",qprime.shape)
qpp = rearrange.inverse(qprime)
print("qpp.shape =",qpp.shape)
assert np.array_equal(q, qpp)

Output:

q.shape = (8, 2, 16) 
qprime.shape = (128, 2)
qpp.shape = (8, 2, 16)

Example 5:

# Example 5: a fresh instance has no recorded forward call, so inverse() raises
# AssertionError("inverse called before forward method").
rearrange2 = RearrangeWrapper()     # test what happens if you call inverse first
z = rearrange2.inverse(y2, infer_dim='c')

Output:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
[<ipython-input-18-7d50cc725597>](https://localhost:8080/#) in <module>
      1 rearrange2 = RearrangeWrapper()     # test what happens if you call inverse first
----> 2 z = rearrange2.inverse(y2, infer_dim='c')

[<ipython-input-14-3426d56c3772>](https://localhost:8080/#) in inverse(self, y, infer_dim)
     11         infer_dim='',                       # axis-letter (from self.s) to try to infer
     12         ):  
---> 13         assert ((self.shape is not None) and (self.s is not None)), "inverse called before forward method"
     14         split = self.s.split('->')          # get 'before' and 'after' strings of forward transform
     15         axes = split[0].strip().split(' ')  # get axis letters, assuming they're space-separated before '->'

AssertionError: inverse called before forward method

@drscotthawley
Copy link
Author

drscotthawley commented Oct 30, 2022

One limitation of the above implementation is that it can only "invert" the most recent call; it cannot reach further "into the past":

# Each forward call overwrites the shape and pattern stored by the previous one.
x = torch.rand((5,2,4))
x2 = torch.rand((2,6,5,4))
y = rearrange(x, 'b c n -> n (c b)')
y2 = rearrange(x2, 'b c w h -> h (c b) w')

# inverse() now only knows the second call ('b c w h -> h (c b) w', input shape
# (2, 6, 5, 4)). Applying that reversed 3-axis pattern to the 2-D y is expected
# to fail inside einops.
yinv = rearrange.inverse(y)  # this won't work because it doesn't "remember" the settings from the first call

This could perhaps be fixed by having the forward call optionally also return some "archival" info that can later be passed to inverse. 🤷 For example:

class RearrangeWrapperArchive:
    "wrapper to endow einops.rearrange with an 'inverse' operation"

    def __init__(self):
        # Input shape and pattern of the most recent forward call. None until
        # the first call so inverse() can detect being called first.
        self.shape, self.s = None, None

    def __call__(self, x, s: str, archive=False, **kwargs):
        """Rearrange ``x`` with pattern ``s``, optionally returning an archive.

        Args:
            x: array/tensor to rearrange.
            s: einops pattern, e.g. ``'b c n -> n (c b)'``.
            archive: if True, also return ``(x.shape, s)``. Pass that to
                ``inverse(..., archive=...)`` to undo this particular call even
                after later forward calls have overwritten the stored state.
            **kwargs: axis lengths forwarded to ``einops.rearrange``.

        Returns:
            The rearranged result, or ``(result, (x.shape, s))`` when
            ``archive`` is True.
        """
        self.shape, self.s = x.shape, s
        y = _rearrange(x, s, **kwargs)  # computed once for both return forms
        # Parenthesised so that archive=False returns the bare result rather
        # than a tuple (a bare `a, b if c else d` groups as `a, (b if c else d)`).
        return (y, (self.shape, self.s)) if archive else y

    def inverse(self,
        y,                                  # torch tensor, e.g., result of forward call
        infer_dim='',                       # axis-letter (from the pattern) to try to infer
        archive=None,                       # tuple of shape & s from a previous call
        ):
        """Undo a forward call: the archived one if given, else the most recent.

        Args:
            y: array/tensor to map back, e.g. the result of a forward call.
            infer_dim: optional axis name from the left-hand side of the
                pattern whose length should not be pinned, so einops infers it
                from ``y``. Names not in the pattern are ignored.
            archive: ``(shape, pattern)`` tuple returned by a forward call with
                ``archive=True``. It is used only for this call and does not
                replace the state stored by the most recent forward call.

        Returns:
            ``y`` rearranged with the reversed pattern (``rhs -> lhs``).

        Raises:
            AssertionError: if there is nothing to invert (no archive and no
                forward call yet), or if the left-hand side of the pattern
                does not name exactly one axis per recorded dimension.
        """
        shape, s = (self.shape, self.s) if archive is None else archive
        # Raised explicitly (not via ``assert``) so the checks survive
        # ``python -O``. AssertionError is kept because callers may rely on it.
        if shape is None or s is None:
            raise AssertionError("inverse called before forward method")
        lhs, _, rhs = s.partition('->')
        # split() with no argument tolerates repeated/leading/trailing spaces.
        axes = lhs.split()
        if len(axes) != len(shape):
            raise AssertionError(
                f"pattern {s!r} names {len(axes)} input axes but the recorded "
                f"input shape has {len(shape)} dimensions"
            )
        # Pin every input axis length so einops can split grouped output axes.
        axes_info = dict(zip(axes, shape))
        axes_info.pop(infer_dim, None)  # let einops infer this axis from y
        return _rearrange(y, rhs + ' -> ' + lhs, **axes_info)

# only have to instantiate this once for the rest of the code, even with different parameters/dims
# With archive=True each forward call also returns (input shape, pattern), so
# either call can be inverted later regardless of which forward call came last.
rearrange = RearrangeWrapperArchive()
x = torch.rand((5,2,4))
y, a1 = rearrange(x, 'b c n -> n (c b)', archive=True)
x2 = torch.rand((2,6,5,4))
y2, a2 = rearrange(x2, 'b c w h -> h (c b) w', archive=True)
# Passing the matching archive restores each original shape.
print(x.shape, y.shape, rearrange.inverse(y, archive=a1).shape)
print(x2.shape, y2.shape, rearrange.inverse(y2, archive=a2).shape)

Output:

torch.Size([5, 2, 4]) torch.Size([4, 10]) torch.Size([5, 2, 4])
torch.Size([2, 6, 5, 4]) torch.Size([4, 12, 5]) torch.Size([2, 6, 5, 4])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment