Commit 96d2ddb
Thiago Crepaldi
Currently (after #114407), the user has must pass the original user ``model`` to APIs such as ``ONNXProgram.__call__``, ``ONNXProgram.adapt_torch_inputs_to_onnx`` and ``ONNXProgram.adapt_torch_outputs_to_onnx`` APIs.
This was needed because when the model is fakefied, a version of the non-fakefied model is needed so that the Initializers, buffers and constants can be extracted from a real model (and used as input to the ONNX model).
That approach brings an unnecessary usability burden to the user when the model is not fakefied, because the model that was already passed to ``torch.onnx.dynamo_export`` could be used to extract ``state_dict``.
This PR adds ``ONNXProgram._model_torch`` attribute to store the user model and demote ``model`` argument of the aforementioned APIs to optional, only (as opposed to required).
As a result, for the fakefied model scenario, the user still need to pass the required model, but for non fakefied models, the persisted model is implicitly used to extract the model state_dict, making it easier to use.
Pull Request resolved: #115281
Approved by: https://github.com/BowenBao
ghstack dependencies: #114407
1 parent 738b4a5 commit 96d2ddb
File tree
7 files changed
+134
-72
lines changed- test/onnx
- torch_export
- torch/onnx/_internal
- fx
7 files changed
+134
-72
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
436 | 436 | | |
437 | 437 | | |
438 | 438 | | |
439 | | - | |
| 439 | + | |
440 | 440 | | |
441 | 441 | | |
442 | | - | |
| 442 | + | |
| 443 | + | |
443 | 444 | | |
444 | | - | |
| 445 | + | |
445 | 446 | | |
446 | 447 | | |
447 | 448 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
198 | 198 | | |
199 | 199 | | |
200 | 200 | | |
201 | | - | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
| 201 | + | |
| 202 | + | |
207 | 203 | | |
208 | 204 | | |
209 | 205 | | |
210 | 206 | | |
211 | 207 | | |
212 | | - | |
213 | | - | |
214 | | - | |
215 | | - | |
216 | | - | |
217 | | - | |
| 208 | + | |
| 209 | + | |
218 | 210 | | |
219 | 211 | | |
220 | 212 | | |
| |||
839 | 831 | | |
840 | 832 | | |
841 | 833 | | |
842 | | - | |
| 834 | + | |
843 | 835 | | |
844 | 836 | | |
845 | | - | |
| 837 | + | |
846 | 838 | | |
847 | 839 | | |
848 | 840 | | |
| |||
1077 | 1069 | | |
1078 | 1070 | | |
1079 | 1071 | | |
| 1072 | + | |
1080 | 1073 | | |
1081 | | - | |
| 1074 | + | |
1082 | 1075 | | |
1083 | 1076 | | |
| 1077 | + | |
1084 | 1078 | | |
1085 | | - | |
| 1079 | + | |
1086 | 1080 | | |
1087 | 1081 | | |
1088 | 1082 | | |
| |||
Lines changed: 2 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
34 | | - | |
35 | | - | |
36 | | - | |
| 34 | + | |
37 | 35 | | |
38 | 36 | | |
39 | | - | |
| 37 | + | |
40 | 38 | | |
41 | 39 | | |
42 | 40 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
659 | 659 | | |
660 | 660 | | |
661 | 661 | | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
662 | 665 | | |
663 | 666 | | |
664 | 667 | | |
| |||
671 | 674 | | |
672 | 675 | | |
673 | 676 | | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
674 | 680 | | |
675 | 681 | | |
676 | 682 | | |
| 683 | + | |
677 | 684 | | |
678 | 685 | | |
679 | 686 | | |
| |||
683 | 690 | | |
684 | 691 | | |
685 | 692 | | |
686 | | - | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
687 | 696 | | |
688 | 697 | | |
689 | 698 | | |
| |||
692 | 701 | | |
693 | 702 | | |
694 | 703 | | |
695 | | - | |
| 704 | + | |
| 705 | + | |
696 | 706 | | |
697 | 707 | | |
698 | 708 | | |
699 | 709 | | |
700 | 710 | | |
701 | 711 | | |
702 | 712 | | |
703 | | - | |
| 713 | + | |
| 714 | + | |
| 715 | + | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
704 | 719 | | |
705 | 720 | | |
706 | 721 | | |
| |||
809 | 824 | | |
810 | 825 | | |
811 | 826 | | |
812 | | - | |
| 827 | + | |
813 | 828 | | |
814 | 829 | | |
815 | 830 | | |
| |||
828 | 843 | | |
829 | 844 | | |
830 | 845 | | |
831 | | - | |
832 | 846 | | |
| 847 | + | |
| 848 | + | |
| 849 | + | |
833 | 850 | | |
834 | 851 | | |
835 | 852 | | |
| |||
841 | 858 | | |
842 | 859 | | |
843 | 860 | | |
844 | | - | |
| 861 | + | |
845 | 862 | | |
846 | 863 | | |
847 | 864 | | |
| |||
857 | 874 | | |
858 | 875 | | |
859 | 876 | | |
860 | | - | |
| 877 | + | |
861 | 878 | | |
862 | 879 | | |
863 | | - | |
| 880 | + | |
864 | 881 | | |
865 | 882 | | |
866 | 883 | | |
867 | 884 | | |
868 | 885 | | |
869 | 886 | | |
870 | | - | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
| 892 | + | |
| 893 | + | |
| 894 | + | |
871 | 895 | | |
872 | 896 | | |
873 | 897 | | |
874 | 898 | | |
875 | | - | |
876 | 899 | | |
| 900 | + | |
| 901 | + | |
| 902 | + | |
877 | 903 | | |
878 | 904 | | |
879 | 905 | | |
| |||
891 | 917 | | |
892 | 918 | | |
893 | 919 | | |
| 920 | + | |
| 921 | + | |
| 922 | + | |
894 | 923 | | |
895 | 924 | | |
896 | 925 | | |
| |||
912 | 941 | | |
913 | 942 | | |
914 | 943 | | |
915 | | - | |
| 944 | + | |
916 | 945 | | |
917 | 946 | | |
918 | 947 | | |
919 | 948 | | |
920 | 949 | | |
921 | 950 | | |
922 | | - | |
| 951 | + | |
| 952 | + | |
| 953 | + | |
| 954 | + | |
| 955 | + | |
| 956 | + | |
923 | 957 | | |
924 | 958 | | |
925 | 959 | | |
| |||
1053 | 1087 | | |
1054 | 1088 | | |
1055 | 1089 | | |
| 1090 | + | |
1056 | 1091 | | |
1057 | 1092 | | |
1058 | 1093 | | |
| |||
1182 | 1217 | | |
1183 | 1218 | | |
1184 | 1219 | | |
| 1220 | + | |
1185 | 1221 | | |
1186 | 1222 | | |
1187 | 1223 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
132 | 132 | | |
133 | 133 | | |
134 | 134 | | |
135 | | - | |
136 | 135 | | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
137 | 139 | | |
138 | 140 | | |
139 | 141 | | |
140 | | - | |
| 142 | + | |
141 | 143 | | |
142 | 144 | | |
143 | 145 | | |
| |||
163 | 165 | | |
164 | 166 | | |
165 | 167 | | |
166 | | - | |
| 168 | + | |
167 | 169 | | |
168 | 170 | | |
169 | 171 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
169 | 169 | | |
170 | 170 | | |
171 | 171 | | |
172 | | - | |
| 172 | + | |
173 | 173 | | |
174 | 174 | | |
175 | 175 | | |
| |||
0 commit comments