Skip to content

torch.jit.script does not work with Tensorized Models #33

@hello-fri-end

Description

@hello-fri-end

Minimal reproducible example:

import torch
from torch.nn import Module
from tltorch import FactorizedConv

class Test(Module):
    """Minimal module wrapping a single Tucker-factorized convolution.

    The module defines no ``forward`` of its own; it exists only so that
    ``torch.jit.script`` recurses into ``self.layer`` and hits the
    ``**kwargs`` in the factorized tensor's ``forward`` signature.
    """

    def __init__(self):
        super().__init__()
        # Positional args (3, 4, 3) are presumably in_channels, out_channels,
        # kernel_size -- confirm against the tltorch FactorizedConv docs.
        self.layer = FactorizedConv(
            3, 4, 3,
            factorization='tucker',
            order=3,
        )

def main():
    """Build the repro model and try to compile it with TorchScript.

    Per the traceback reported in this issue, the ``torch.jit.script`` call
    fails with ``torch.jit.frontend.NotSupportedError``.
    """
    # Instantiate the model
    net = Test()
    _scripted = torch.jit.script(net)


if __name__ == "__main__":
    main()

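Error (note: this traceback was captured from a slightly different version of the script, in which the `torch.jit.script` call lives in a `save_model(model)` helper; the failure is the same):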

Traceback (most recent call last):
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 27, in <module>
    main()
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 24, in main
    save_model(model)
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 8, in save_model
    scripted_module = torch.jit.script(model)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 572, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 899, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 87, in make_stub_from_method
    return make_stub(func, method_name)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 71, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 372, in get_jit_def
    return build_def(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 422, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 448, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/tltorch/factorized_tensors/core.py", line 259
    def forward(self, indices=None, **kwargs):
                                     ~~~~~~~ <--- HERE
        """To use a tensor factorization within a network, use ``tensor.forward``, or, equivalently, ``tensor()`

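The root cause is that `torch.jit.script` cannot compile functions that take a variable number of arguments (`*args` / `**kwargs`) or keyword-only arguments with defaults. The `forward` methods of the factorized/tensorized layers use exactly these constructs, for example `def forward(self, indices=None, **kwargs)` in `tltorch/factorized_tensors/core.py`.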

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions