
Commit 96d2ddb

Author: Thiago Crepaldi
Store user model to simplify ONNXProgram.{adapt_torch_*,__call__} APIs (#115281) (#115583)
Currently (after #114407), the user must pass the original ``model`` to APIs such as ``ONNXProgram.__call__``, ``ONNXProgram.adapt_torch_inputs_to_onnx`` and ``ONNXProgram.adapt_torch_outputs_to_onnx``. This was needed because when the model is fakefied, a non-fakefied version of the model is required so that the initializers, buffers and constants can be extracted from a real model (and used as input to the ONNX model). That approach places an unnecessary usability burden on the user when the model is not fakefied, because the model already passed to ``torch.onnx.dynamo_export`` could be used to extract the ``state_dict``.

This PR adds an ``ONNXProgram._model_torch`` attribute to store the user model and demotes the ``model`` argument of the aforementioned APIs (renamed to ``model_with_state_dict``) from required to optional. As a result, in the fakefied-model scenario the user still needs to pass the model, but for non-fakefied models the persisted model is implicitly used to extract the model ``state_dict``, making the APIs easier to use.

Pull Request resolved: #115281
Approved by: https://github.com/BowenBao
ghstack dependencies: #114407
1 parent 738b4a5 commit 96d2ddb
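
To illustrate the change, here is a minimal before/after sketch for the common (non-fakefied) case. This is not part of the commit; the module and tensor names are made up for illustration:

import torch

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Add()
x, y = torch.randn(2), torch.randn(2)
onnx_program = torch.onnx.dynamo_export(model, x, y)

# Before this PR, the model had to be passed back in explicitly:
#   ort_outputs = onnx_program(x, y, model=model)
# After this PR, the model stored at export time is used implicitly:
ort_outputs = onnx_program(x, y)
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(x, y)
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(model(x, y))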

7 files changed (+134, -72 lines)

test/onnx/onnx_test_common.py

Lines changed: 4 additions & 3 deletions
@@ -436,12 +436,13 @@ def _compare_pytorch_onnx_with_ort(
     ref_input_args = input_args
     ref_input_kwargs = input_kwargs

-    # ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
+    # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
     # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
     # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
-    ort_outputs = onnx_program(*input_args, model=ref_model, **input_kwargs)
+    # NOTE: `model_with_state_dict=ref_model` is specified to cover runs with FakeTensor support
+    ort_outputs = onnx_program(*input_args, model_with_state_dict=ref_model, **input_kwargs)
     ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
-    ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_model, ref_outputs)
+    ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)

     if len(ref_outputs) != len(ort_outputs):
         raise AssertionError(
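
The NOTE above matters because a module's forward() can mutate registered buffers in place. A hypothetical module (not from this commit) illustrating why ONNXProgram() must run before ref_model():

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("calls", torch.zeros(1))

    def forward(self, x):
        self.calls += 1  # mutates the buffer, and therefore the state_dict
        return x + self.calls

# ONNXProgram reads the live state_dict (a reference, not a copy), so
# running ref_model first would hand ONNX Runtime the mutated buffer value.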

test/onnx/test_fx_to_onnx_with_onnxruntime.py

Lines changed: 10 additions & 16 deletions
@@ -198,23 +198,15 @@ def func(x, b=1.0):
            ),
        )
        onnx_test_common.assert_dynamic_shapes(onnx_program, self.dynamic_shapes)
-        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
-            tensor_x, model=func, b=8.0
-        )
-        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
-            func, func(tensor_x, 8.0)
-        )
+        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, b=8.0)
+        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 8.0))
        ort_outputs = onnx_test_common.run_ort(onnx_program, onnx_format_args)
        for ref_output, ort_output in zip(ref_outputs, ort_outputs):
            torch.testing.assert_close(ref_output, torch.tensor(ort_output))

        # test on different non-tensor input - xfail
-        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
-            tensor_x, model=func, b=9.0
-        )
-        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
-            func, func(tensor_x, 9.0)
-        )
+        onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(tensor_x, b=9.0)
+        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(func(tensor_x, 9.0))
        _ = onnx_test_common.run_ort(onnx_program, onnx_format_args)
        for ref_output, ort_output in zip(ref_outputs, ort_outputs):
            torch.testing.assert_close(ref_output, torch.tensor(ort_output))
@@ -839,10 +831,10 @@ def _test_fx_symbolic_tracer_large_scale_exporter(
        kwargs = create_pytorch_only_kwargs()
        # Original outputs.
        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
-            model, model(*args, **kwargs)
+            model(*args, **kwargs)
        )
        # ORT outputs.
-        args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args, model=model)
+        args_not_none = onnx_program.adapt_torch_inputs_to_onnx(*args)

        # Drop Parameters and buffers added by fx_serialization.save_model_with_external_data
        args_not_none = args_not_none[: len(args) - len(kwargs)]
@@ -1077,12 +1069,14 @@ def _test_fake_tensor_mode_exporter(
        args = create_args()
        kwargs = create_kwargs()
        # Original outputs.
+        # model_with_state_dict=real_model is used to create non-fake weights
        ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
-            fake_model, real_model(*args, **kwargs)
+            real_model(*args, **kwargs), model_with_state_dict=real_model
        )
        # ORT outputs.
+        # model_with_state_dict=real_model is used to create non-fake weights
        args_not_none = onnx_program.adapt_torch_inputs_to_onnx(
-            *args, model=real_model, **kwargs
+            *args, model_with_state_dict=real_model, **kwargs
        )

        ort_outputs = onnx_test_common.run_ort(
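
For the fake-tensor path exercised by _test_fake_tensor_mode_exporter, the calling convention looks roughly like the sketch below. MyModel and the tensor shapes are illustrative stand-ins, not names from this commit; see the torch.onnx.enable_fake_mode documentation for the full workflow:

import torch

with torch.onnx.enable_fake_mode() as fake_context:
    fake_model = MyModel()  # MyModel is a stand-in nn.Module; weights are FakeTensors here
    fake_args = torch.randn(2, 4)
    export_options = torch.onnx.ExportOptions(fake_context=fake_context)
    onnx_program = torch.onnx.dynamo_export(
        fake_model, fake_args, export_options=export_options
    )

real_model = MyModel()  # real weights, created outside fake mode
real_args = torch.randn(2, 4)

# The fake export holds no real initializers, so a real model must be
# supplied explicitly to materialize them:
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(
    real_args, model_with_state_dict=real_model
)
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
    real_model(real_args), model_with_state_dict=real_model
)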

test/onnx/torch_export/test_torch_export_with_onnxruntime.py

Lines changed: 2 additions & 4 deletions
@@ -31,12 +31,10 @@ def _compare_onnx_and_torch_exported_program(
    # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
    # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
    # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
-    onnx_outputs = onnx_exported_program(
-        *input_args, model=torch_exported_program, **input_kwargs
-    )
+    onnx_outputs = onnx_exported_program(*input_args, **input_kwargs)
    torch_outputs = torch_exported_program(*input_args, **input_kwargs)
    torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx(
-        torch_exported_program, torch_outputs
+        torch_outputs
    )
    if len(torch_outputs_onnx_format) != len(onnx_outputs):
        raise AssertionError(

torch/onnx/_internal/exporter.py

Lines changed: 48 additions & 12 deletions
@@ -659,6 +659,9 @@ class ONNXProgram:
    _fake_context: Final[Optional[ONNXFakeContext]]
    _export_exception: Final[Optional[Exception]]
    _model_signature: Final[Optional[torch.export.ExportGraphSignature]]
+    _model_torch: Final[
+        Optional[Union[torch.nn.Module, Callable, torch_export.ExportedProgram]]
+    ]

    @_beartype.beartype
    def __init__(
@@ -671,9 +674,13 @@ def __init__(
        fake_context: Optional[ONNXFakeContext] = None,
        export_exception: Optional[Exception] = None,
        model_signature: Optional[torch.export.ExportGraphSignature] = None,
+        model_torch: Optional[
+            Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
+        ] = None,
    ):
        self._model_proto = model_proto
        self._model_signature = model_signature
+        self._model_torch = model_torch
        self._input_adapter = input_adapter
        self._output_adapter = output_adapter
        self._diagnostic_context = diagnostic_context
@@ -683,7 +690,9 @@ def __init__(
    def __call__(
        self,
        *args: Any,
-        model: Union[torch.nn.Module, Callable, torch_export.ExportedProgram],
+        model_with_state_dict: Optional[
+            Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
+        ] = None,
        options: Optional[ONNXRuntimeOptions] = None,
        **kwargs: Any,
    ) -> Any:
@@ -692,15 +701,21 @@ def __call__(
        Args:
            args: The positional inputs to the model.
            kwargs: The keyword inputs to the model.
-            model: The PyTorch model to fetch state from.
+            model_with_state_dict: The PyTorch model to fetch state from.
+                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.
            options: The options to use for running the model with ONNX Runtime.

        Returns:
            The model output as computed by ONNX Runtime
        """
        import onnxruntime  # type: ignore[import]

-        onnx_input = self.adapt_torch_inputs_to_onnx(*args, model=model, **kwargs)
+        # The model specified by the user takes precedence, when specified
+        model_with_state_dict = model_with_state_dict or self._model_torch
+
+        onnx_input = self.adapt_torch_inputs_to_onnx(
+            *args, model_with_state_dict=model_with_state_dict, **kwargs
+        )
        options = options or ONNXRuntimeOptions()
        providers = options.execution_providers or onnxruntime.get_available_providers()
        onnx_model = self.model_proto.SerializeToString()
@@ -809,7 +824,7 @@ def fake_context(self) -> Optional[ONNXFakeContext]:
    def adapt_torch_inputs_to_onnx(
        self,
        *model_args,
-        model: Optional[
+        model_with_state_dict: Optional[
            Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
        ] = None,
        **model_kwargs,
@@ -828,8 +843,10 @@ def adapt_torch_inputs_to_onnx(
        This method replays the adapting steps recorded during export.

        Args:
-            model: The PyTorch model to get extra state from. If not specified, the model used during export is used.
            model_args: The PyTorch model inputs.
+            model_with_state_dict: The PyTorch model to get extra state from.
+                If not specified, the model used during export is used.
+                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.
            model_kwargs: The PyTorch model keyword inputs.

        Returns:
@@ -841,7 +858,7 @@ def adapt_torch_inputs_to_onnx(
            >>> import torch
            >>> import torch.onnx
            >>> from typing import Dict, Tuple
-            >>> def func_with_nested_input_structure(
+            >>> def func_nested_input(
            ...     x_dict: Dict[str, torch.Tensor],
            ...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            ... ):
@@ -857,23 +874,32 @@ def adapt_torch_inputs_to_onnx(
            ...     return x + y1 + y2 + y3
            >>> x_dict = {"a": torch.tensor(1.)}
            >>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
-            >>> onnx_program = torch.onnx.dynamo_export(func_with_nested_input_structure, x_dict, y_tuple)
+            >>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple)
            >>> print(x_dict, y_tuple)
            {'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
-            >>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model=func_with_nested_input_structure))
+            >>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input))
            (tensor(1.), tensor(2.), tensor(3.), tensor(4.))

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
-        return self._input_adapter.apply(*model_args, model=model, **model_kwargs)
+        # The model specified by the user takes precedence, when specified
+        model_with_state_dict = model_with_state_dict or self._model_torch
+        assert (
+            model_with_state_dict is not None
+        ), "model_with_state_dict must be specified."
+        return self._input_adapter.apply(
+            *model_args, model=model_with_state_dict, **model_kwargs
+        )

    @_beartype.beartype
    def adapt_torch_outputs_to_onnx(
        self,
-        model: Union[torch.nn.Module, Callable, torch_export.ExportedProgram],
        model_outputs: Any,
+        model_with_state_dict: Optional[
+            Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
+        ] = None,
    ) -> Sequence[Union[torch.Tensor, int, float, bool]]:
        """Converts the PyTorch model outputs to exported ONNX model outputs format.

@@ -891,6 +917,9 @@ def adapt_torch_outputs_to_onnx(
        Args:
            model: The PyTorch model to get extra state from.
            model_outputs: The PyTorch model outputs.
+            model_with_state_dict: The PyTorch model to get extra state from.
+                If not specified, the model used during export is used.
+                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.

        Returns:
            PyTorch model outputs in exported ONNX model outputs format.
@@ -912,14 +941,19 @@ def adapt_torch_outputs_to_onnx(
            >>> pt_output = func_returning_tuples(x, y, z)
            >>> print(pt_output)
            (tensor(3.), (tensor(5.), tensor(8.)))
-            >>> print(onnx_program.adapt_torch_outputs_to_onnx(func_returning_tuples, pt_output))
+            >>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples))
            [tensor(3.), tensor(5.), tensor(8.)]

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
-        return self._output_adapter.apply(model, model_outputs)
+        # The model specified by the user takes precedence, when specified
+        model_with_state_dict = model_with_state_dict or self._model_torch
+        assert (
+            model_with_state_dict is not None
+        ), "model_with_state_dict must be specified."
+        return self._output_adapter.apply(model_outputs, model=model_with_state_dict)

    @_beartype.beartype
    def save(
@@ -1053,6 +1087,7 @@ def _from_failure(
        # https://github.com/pytorch/pytorch/issues/103764
        import onnx

+        # TODO: Should we populate ONNXProgram with more info, such as _model_torch, for easier debugging?
        return ONNXProgram(
            onnx.ModelProto(),  # type: ignore[attr-defined]
            io_adapter.InputAdapter(),
@@ -1182,6 +1217,7 @@ def export(self) -> ONNXProgram:
            model_signature=getattr(
                self.model, "graph_signature", None
            ),  # Available for isinstance(self.model, ExportedProgram) only
+            model_torch=self.model,
        )

    def _assert_fake_tensor_mode(self):
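
The net effect of the precedence logic above, as a usage sketch (onnx_program, x and real_model are illustrative names, assuming an export as in the docstring examples):

# Common case: the model captured at export time is used implicitly.
ort_outputs = onnx_program(x)

# An explicitly passed model takes precedence over the stored one; this is
# required after enable_fake_mode(), where the stored model holds fake weights.
ort_outputs = onnx_program(x, model_with_state_dict=real_model)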

torch/onnx/_internal/fx/dynamo_graph_extractor.py

Lines changed: 5 additions & 3 deletions
@@ -132,12 +132,14 @@ def __init__(

    def apply(
        self,
-        model: Union[torch.nn.Module, Callable, torch_export.ExportedProgram],
        model_outputs: Any,
+        model: Optional[
+            Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
+        ] = None,
    ) -> Sequence[Any]:
        """Flatten the model outputs, under the context of pytree extension."""
        with self._pytree_extension_context:
-            return super().apply(model, model_outputs)
+            return super().apply(model_outputs, model=model)


def _wrap_model_with_output_adapter(
@@ -163,7 +165,7 @@ def _wrap_model_with_output_adapter(
    # Preserve original function signature.
    @functools.wraps(model_func)
    def wrapped(*args, **kwargs):
-        return output_adapter.apply(model, model_func(*args, **kwargs))
+        return output_adapter.apply(model_func(*args, **kwargs), model=model)

    return wrapped
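
_wrap_model_with_output_adapter relies on a standard decorator pattern: wrap a callable, post-process its outputs, and keep the original signature visible to inspection tools. A generic, self-contained sketch of that pattern (not the commit's code):

import functools

def wrap_with_postprocess(fn, postprocess):
    @functools.wraps(fn)  # preserves __name__, __doc__ and the visible signature
    def wrapped(*args, **kwargs):
        return postprocess(fn(*args, **kwargs))
    return wrapped

double = wrap_with_postprocess(lambda x: x + 1, lambda out: out * 2)
assert double(3) == 8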

torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def _trace_into_fx_graph_via_fx_symbolic_trace(
            torch.onnx.utils.model_signature(model)
        )
        self.input_adapter.append_step(bind_input_step)
-        _, named_args = bind_input_step.apply(model, model_args, model_kwargs)
+        _, named_args = bind_input_step.apply(model_args, model_kwargs, model=model)

        # Create inputs to call symbolic trace (torch.fx.symbolic_trace)
        # Example content of concrete_args:
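
bind_input_step.apply binds positional and keyword inputs to the model's named parameters. A rough sketch of what such a binding step does, using only the standard library (the actual BindInputStep lives elsewhere in io_adapter):

import inspect

def bind_named_args(fn, args, kwargs):
    # Normalize positional and keyword inputs into an ordered mapping
    # keyed by the callable's parameter names, filling in defaults.
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.arguments

def f(x, b=1.0):
    return x + b

assert bind_named_args(f, (2.0,), {}) == {"x": 2.0, "b": 1.0}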
