Commit 4cf10bf

[Cherry-pick] [Quant] [PT2] Enable batchnorm in _move_exported_model_to_eval (#115715)
1 parent 7e97e4b commit 4cf10bf
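
For context, a minimal usage sketch of the behavior this commit enables, adapted from the new test in test/quantization/pt2e/test_quantize_pt2e.py below (it assumes capture_pre_autograd_graph is importable from torch._export in this PyTorch build):

import torch
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(self.bn(x))

example_inputs = (torch.randn(1, 3, 8, 8),)
# Exporting the train-mode model captures batch norm as aten._native_batch_norm_legit.
m = capture_pre_autograd_graph(M().train(), example_inputs)

# With this commit, the call below rewrites batch norm as well as dropout:
# the train-mode BN node becomes aten._native_batch_norm_legit_no_training.
m = torch.ao.quantization.move_exported_model_to_eval(m)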

File tree

5 files changed: +139 −8 lines

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 37 additions & 0 deletions
@@ -1759,6 +1759,43 @@ def forward(self, x):
             check_dynamic=True,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qat_bn_conv2d(self):
+        r"""
+        This testcase will quantize a single BN Conv2d module with qat flow.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+            ):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+                self.bn1 = torch.nn.BatchNorm2d(3)
+                self.bn2 = torch.nn.BatchNorm2d(3)
+
+            def forward(self, x):
+                x = self.conv(self.bn1(x))
+                return self.bn2(x)
+
+        mod = M().train()
+        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
+
+        def matcher_check_fn():
+            self.assertEqual(
+                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
+            )
+
+        self._test_common(
+            mod,
+            (v,),
+            check_quantization=True,
+            is_qat=True,
+            matcher_check_fn=matcher_check_fn,
+        )
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 36 additions & 0 deletions
@@ -1652,6 +1652,42 @@ def test_move_exported_model_to_eval(self):
         self._test_move_exported_model_to_eval_dropout(inplace=False)
         self._test_move_exported_model_to_eval_dropout(inplace=True)
 
+    def test_bn_move_exported_model_to_eval(self):
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+            ):
+                super().__init__()
+                self.bn = torch.nn.BatchNorm2d(3)
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                return self.conv(self.bn(x))
+
+        m = M().train()
+        example_inputs = (
+            torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1),
+        )
+
+        m = capture_pre_autograd_graph(m, example_inputs)
+
+        # Assert that bn op exists and is in train mode
+        batch_norm_node = None
+        for n in m.graph.nodes:
+            if n.target == torch.ops.aten._native_batch_norm_legit.default:
+                batch_norm_node = n
+                break
+        self.assertTrue(batch_norm_node is not None)
+        self.assertTrue(batch_norm_node.args[5])
+
+        # Do the subgraph rewriting
+        torch.ao.quantization.move_exported_model_to_eval(m)
+
+        # Assert that bn op is now in eval mode
+        targets = [n.target for n in m.graph.nodes]
+        self.assertTrue(torch.ops.aten._native_batch_norm_legit.default not in targets)
+        self.assertTrue(torch.ops.aten._native_batch_norm_legit_no_training.default in targets)
+
     def test_disallow_eval_train(self):
         m = TestHelperModules.ConvWithBNRelu(relu=True)
         example_inputs = (torch.rand(3, 3, 5, 5),)
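
The assertions in the test above boil down to checking which aten batch norm variant appears among the graph's node targets; a small standalone helper sketch (hypothetical, not part of the commit) that captures the same check:

import torch

def bn_targets_in_graph(m: torch.fx.GraphModule):
    # Collect the batch norm variants present in an exported graph:
    # _native_batch_norm_legit is the train-mode op,
    # _native_batch_norm_legit_no_training the eval-mode op.
    bn_ops = {
        torch.ops.aten._native_batch_norm_legit.default,
        torch.ops.aten._native_batch_norm_legit_no_training.default,
    }
    return [n.target for n in m.graph.nodes if n.target in bn_ops]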

test/quantization/pt2e/test_quantize_pt2e_qat.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper(
         self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)
 
         if verify_convert:
+            # We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
             torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
             model_pt2e = convert_pt2e(model_pt2e)
             quant_result_pt2e = model_pt2e(*example_inputs)

torch/ao/quantization/pt2e/eval_utils.py

Lines changed: 56 additions & 2 deletions
@@ -45,14 +45,68 @@ def dropout_eval(x):
     m.recompile()
 
 
+def _replace_batchnorm_for_eval(m: torch.fx.GraphModule):
+    # TODO(Leslie): This function still fails to support custom momentum and eps value.
+    # Enable this support in future updates.
+
+    # Avoid circular dependencies
+    from .utils import get_aten_graph_module
+
+    # Needed to ensure subgraph matches are self-contained
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    def bn_train(
+        x: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ):
+        return F.batch_norm(
+            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
+        )
+
+    def bn_eval(
+        x: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ):
+        return F.batch_norm(
+            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
+        )
+
+    example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+    match_pattern = get_aten_graph_module(bn_train, example_inputs)
+    replacement_pattern = get_aten_graph_module(bn_eval, example_inputs)
+    from torch.fx.subgraph_rewriter import replace_pattern_with_filters
+
+    replace_pattern_with_filters(
+        m,
+        match_pattern,
+        replacement_pattern,
+        match_filters=[],
+        ignore_literals=True,
+    )
+    m.recompile()
+
+
 # TODO: also support move_exported_model_to_train
-# TODO: also support standalone batchnorm
 def _move_exported_model_to_eval(model: torch.fx.GraphModule):
     """
     Move an exported GraphModule to eval mode.
 
-    This is equivalent to model.eval() but only for certain special ops like dropout.
+    This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
     QAT users should call this before performing inference on the model.
     """
     _replace_dropout_for_eval(model)
+    _replace_batchnorm_for_eval(model)
     return model
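
For reference, the bn_train and bn_eval patterns above differ only in the training flag passed to F.batch_norm; a standalone sketch of what that flag changes (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# training=True: normalize with batch statistics and update running_mean/running_var in place.
y_train = F.batch_norm(x, running_mean, running_var, weight, bias, training=True)

# training=False: normalize with the stored running statistics and leave them untouched.
y_eval = F.batch_norm(x, running_mean, running_var, weight, bias, training=False)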

torch/ao/quantization/pt2e/qat_utils.py

Lines changed: 9 additions & 6 deletions
@@ -180,6 +180,7 @@ def _get_quantized_qat_conv_bn_pattern(
     has_bias: bool,
     bias_is_quantized: bool,
     conv_fn: Callable,
+    bn_is_training: bool,
 ) -> Callable:
     """
     Return the quantized version of QAT conv + BN pattern.

@@ -218,7 +219,7 @@ def _quantized_qat_conv_bn_pattern(
         x = x / scale_factor.reshape(bias_shape)
         if has_bias:
             x = x + kwargs["conv_bias"].reshape(bias_shape)
-        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
+        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=bn_is_training, eps=bn_eps)
         return x
     return _quantized_qat_conv_bn_pattern

@@ -227,6 +228,7 @@ def _get_folded_quantized_qat_conv_bn_pattern(
     has_bias: bool,
     bias_is_quantized: bool,
     conv_fn: Callable,
+    bn_is_training: bool,
 ) -> Callable:
     """
     Quantized QAT conv - bn pattern with bn weights being folded into conv.

@@ -251,7 +253,7 @@ def _folded_quantized_qat_conv_bn_pattern(
         else:
             bias = None
         x = conv_fn(x, conv_weight, bias)
-        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
+        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=bn_is_training, eps=bn_eps)
         return x
     return _folded_quantized_qat_conv_bn_pattern

@@ -322,7 +324,7 @@ def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
         if _is_conv(n):
             assert conv_node is None
             conv_node = n
-        if _is_supported_batch_norm_for_training(n):
+        if _is_supported_batch_norm_for_training(n) or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
             assert bn_node is None
             bn_node = n
         if n.target == operator.getitem:

@@ -715,19 +717,20 @@ def _fold_conv_bn_qat_helper(
         [True, False],  # is_per_channel
         [True, False],  # has_bias
         [True, False],  # bias_is_quantized
+        [True, False],  # bn_is_training
     )
-    for is_per_channel, has_bias, bias_is_quantized in replacement_options:
+    for is_per_channel, has_bias, bias_is_quantized, bn_is_training in replacement_options:
         # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
         # filter out one of the values for this flag to avoid having duplicate patterns
         if not has_bias and bias_is_quantized:
             continue
         kwargs = _get_quantized_conv_bn_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
         match_pattern = _get_quantized_qat_conv_bn_pattern(
-            is_per_channel, has_bias, bias_is_quantized, conv_fn,
+            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
         )
         match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
         replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
-            is_per_channel, has_bias, bias_is_quantized, conv_fn
+            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
         )
         replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
         replacements.extend(
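
As a side note on the last hunk, adding bn_is_training to the itertools.product call doubles the set of (match, replacement) pattern pairs, so the fold pass matches conv + BN subgraphs whether the BN node is still in training mode or has already been moved to eval; a standalone sketch of that enumeration:

import itertools

replacement_options = itertools.product(
    [True, False],  # is_per_channel
    [True, False],  # has_bias
    [True, False],  # bias_is_quantized
    [True, False],  # bn_is_training
)
for is_per_channel, has_bias, bias_is_quantized, bn_is_training in replacement_options:
    # Without a bias, bias_is_quantized is irrelevant; skip one value to avoid duplicate patterns.
    if not has_bias and bias_is_quantized:
        continue
    print(is_per_channel, has_bias, bias_is_quantized, bn_is_training)  # 12 combinations remain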
