Skip to content

Commit 64024ac

Browse files
cyrjanofacebook-github-bot
authored andcommitted
Fix CUDA flaky tests for stochastic gates by using CPU-seeded RNG
Summary: ## Problem CUDA RNG produces different random sequences on different GPU architectures (e.g. V100 vs A100 vs H100) even with the same seed set via `torch.manual_seed()`. This causes stochastic gate CUDA tests to be flaky in CI — the same test passes on one GPU type but fails on another because expected values were hardcoded for a specific architecture's RNG output. Additionally, `test_p_norm_decay` uses exact `assert ==` for floating-point tensor comparison, which fails on GPU due to floating-point precision differences. ## Solution **CPU-seeded RNG approach**: In CUDA test subclasses, patch `_sample_gate_values` to generate random noise on CPU (where `torch.manual_seed` is deterministic across all hardware) and then move the tensor to the GPU device. This keeps the full training codepath exercised (noise + mu → clamp → gather → multiply) while ensuring cross-architecture determinism. For `LazyGaussianStochasticGates`, both `initialize_parameters` (mu initialization) and `_sample_gate_values` (noise sampling) happen on-device after `.to(cuda)`, so both are patched to use CPU RNG. Since CUDA tests now produce identical values to CPU tests, the `if cpu / elif cuda` branches in base test files are removed, along with associated `pyre-fixme[61]` comments. For `test_p_norm_decay`, exact `assert ==` is replaced with `assertTensorAlmostEqual` with `delta=0.01` tolerance. ## Files Changed - `test_gaussian_stochastic_gates_cuda.py`: Patch `_sample_gate_values` with CPU-seeded `normal_()` sampling - `test_kuma_stochastic_gates_cuda.py`: Patch `_sample_gate_values` with CPU-seeded `uniform_()` sampling + Kumaraswamy transform - `test_lazy_gaussian_stochastic_gates_cuda.py`: Patch both `initialize_parameters` and `_sample_gate_values` - `test_gaussian_stochastic_gates.py`: Remove cpu/cuda branches (4 tests) - `test_kuma_stochastic_gates.py`: Remove cpu/cuda branches (4 tests) - `test_lazy_gaussian_stochastic_gates.py`: Remove cpu/cuda branches (12 tests) - `test_p_norm_decay.py`: Use `assertTensorAlmostEqual` instead of exact equality (2 tests) Differential Revision: D97775614
1 parent 458e134 commit 64024ac

4 files changed

Lines changed: 111 additions & 71 deletions

tests/module/test_binary_concrete_stochastic_gates.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,8 @@ def test_bcstg_1d_input(self) -> None:
3232
gated_input, reg = bcstg(input_tensor)
3333
expected_reg = 2.4947
3434

35-
if self.testing_device == "cpu":
36-
expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
37-
elif self.testing_device == "cuda":
38-
expected_gated_input = [[0.0000, 0.0985, 0.1149], [0.2329, 0.0497, 0.5000]]
35+
expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
3936

40-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
4137
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
4238
assertTensorAlmostEqual(self, reg, expected_reg)
4339

@@ -110,12 +106,8 @@ def test_bcstg_1d_input_with_mask(self) -> None:
110106
gated_input, reg = bcstg(input_tensor)
111107
expected_reg = 1.6643
112108

113-
if self.testing_device == "cpu":
114-
expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
115-
elif self.testing_device == "cuda":
116-
expected_gated_input = [[0.0000, 0.0000, 0.1971], [0.1737, 0.2317, 0.3888]]
109+
expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
117110

118-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
119111
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
120112
assertTensorAlmostEqual(self, reg, expected_reg)
121113

@@ -143,18 +135,10 @@ def test_bcstg_2d_input(self) -> None:
143135
gated_input, reg = bcstg(input_tensor)
144136

145137
expected_reg = 4.9903
146-
expected_gated_input = []
147-
148-
if self.testing_device == "cpu":
149-
expected_gated_input = [
150-
[[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
151-
[[0.0476, 0.6177], [0.5400, 0.1530], [0.0984, 0.8013]],
152-
]
153-
elif self.testing_device == "cuda":
154-
expected_gated_input = [
155-
[[0.0000, 0.0985], [0.1149, 0.2331], [0.0486, 0.5000]],
156-
[[0.1840, 0.1571], [0.4612, 0.7937], [0.2975, 0.7393]],
157-
]
138+
expected_gated_input = [
139+
[[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
140+
[[0.0476, 0.6177], [0.5400, 0.1530], [0.0984, 0.8013]],
141+
]
158142

159143
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
160144
assertTensorAlmostEqual(self, reg, expected_reg)
@@ -207,18 +191,11 @@ def test_bcstg_2d_input_with_mask(self) -> None:
207191
gated_input, reg = bcstg(input_tensor)
208192
expected_reg = 2.4947
209193

210-
if self.testing_device == "cpu":
211-
expected_gated_input = [
212-
[[0.0000, 0.0212], [0.0424, 0.0636], [0.3191, 0.4730]],
213-
[[0.3678, 0.6568], [0.7507, 0.8445], [0.6130, 1.0861]],
214-
]
215-
elif self.testing_device == "cuda":
216-
expected_gated_input = [
217-
[[0.0000, 0.0985], [0.1971, 0.2956], [0.0000, 0.2872]],
218-
[[0.4658, 0.0870], [0.0994, 0.1119], [0.7764, 1.1000]],
219-
]
194+
expected_gated_input = [
195+
[[0.0000, 0.0212], [0.0424, 0.0636], [0.3191, 0.4730]],
196+
[[0.3678, 0.6568], [0.7507, 0.8445], [0.6130, 1.0861]],
197+
]
220198

221-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
222199
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
223200
assertTensorAlmostEqual(self, reg, expected_reg)
224201

tests/module/test_binary_concrete_stochastic_gates_cuda.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,59 @@
33

44
# pyre-strict
55

6-
from .test_binary_concrete_stochastic_gates import TestBinaryConcreteStochasticGates
6+
import unittest
7+
from unittest.mock import patch
78

9+
import torch
10+
from captum.module.binary_concrete_stochastic_gates import (
11+
BinaryConcreteStochasticGates,
12+
)
13+
from torch import Tensor
814

9-
class TestBinaryConcreteStochasticGatesCUDA(TestBinaryConcreteStochasticGates):
15+
from .test_binary_concrete_stochastic_gates import (
16+
TestBinaryConcreteStochasticGates,
17+
)
18+
19+
20+
# CUDA RNG produces different sequences on different GPU architectures
21+
# (e.g. V100 vs A100 vs H100) even with the same seed, causing flaky
22+
# tests. By generating uniform samples on CPU and moving to the device,
23+
# tests get consistent results regardless of which GPU type runs them.
24+
def _cpu_rng_sample(
25+
self: BinaryConcreteStochasticGates, batch_size: int
26+
) -> Tensor:
27+
if self.training:
28+
u = torch.empty(batch_size, self.n_gates)
29+
u.uniform_(self.eps, 1 - self.eps)
30+
u = u.to(self.log_alpha_param.device)
31+
s = torch.sigmoid(
32+
(torch.logit(u) + self.log_alpha_param) / self.temperature
33+
)
34+
else:
35+
s = torch.sigmoid(self.log_alpha_param)
36+
s = s.expand(batch_size, self.n_gates)
37+
38+
s_bar = s * (self.upper_bound - self.lower_bound) + self.lower_bound
39+
return s_bar
40+
41+
42+
class TestBinaryConcreteStochasticGatesCUDA(
43+
TestBinaryConcreteStochasticGates,
44+
):
1045
testing_device: str = "cuda"
46+
47+
def setUp(self) -> None:
48+
super().setUp()
49+
if not torch.cuda.is_available():
50+
raise unittest.SkipTest(
51+
"Skipping GPU test since CUDA not available."
52+
)
53+
# pyre-fixme[8]: Attribute has type
54+
# `BoundMethod[..., Tensor]`; used as `(...) -> Tensor`.
55+
patcher = patch.object(
56+
BinaryConcreteStochasticGates,
57+
"_sample_gate_values",
58+
_cpu_rng_sample,
59+
)
60+
patcher.start()
61+
self.addCleanup(patcher.stop)

tests/module/test_gaussian_stochastic_gates.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,8 @@ def test_gstg_1d_input(self) -> None:
3333

3434
gated_input, reg = gstg(input_tensor)
3535
expected_reg = 2.5213
36+
expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
3637

37-
if self.testing_device == "cpu":
38-
expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
39-
elif self.testing_device == "cuda":
40-
expected_gated_input = [[0.0000, 0.0788, 0.0470], [0.0134, 0.0000, 0.1884]]
41-
42-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
4338
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
4439
assertTensorAlmostEqual(self, reg, expected_reg)
4540

@@ -90,13 +85,8 @@ def test_gstg_1d_input_with_mask(self) -> None:
9085

9186
gated_input, reg = gstg(input_tensor)
9287
expected_reg = 1.6849
88+
expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
9389

94-
if self.testing_device == "cpu":
95-
expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
96-
elif self.testing_device == "cuda":
97-
expected_gated_input = [[0.0000, 0.0000, 0.1577], [0.0736, 0.0981, 0.0242]]
98-
99-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
10090
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
10191
assertTensorAlmostEqual(self, reg, expected_reg)
10292

@@ -137,19 +127,11 @@ def test_gstg_2d_input(self) -> None:
137127

138128
gated_input, reg = gstg(input_tensor)
139129
expected_reg = 5.0458
130+
expected_gated_input = [
131+
[[0.0000, 0.0851], [0.0713, 0.3000], [0.2180, 0.1878]],
132+
[[0.2538, 0.0000], [0.3391, 0.8501], [0.3633, 0.8913]],
133+
]
140134

141-
if self.testing_device == "cpu":
142-
expected_gated_input = [
143-
[[0.0000, 0.0851], [0.0713, 0.3000], [0.2180, 0.1878]],
144-
[[0.2538, 0.0000], [0.3391, 0.8501], [0.3633, 0.8913]],
145-
]
146-
elif self.testing_device == "cuda":
147-
expected_gated_input = [
148-
[[0.0000, 0.0788], [0.0470, 0.0139], [0.0000, 0.1960]],
149-
[[0.0000, 0.7000], [0.1052, 0.2120], [0.5978, 0.0166]],
150-
]
151-
152-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
153135
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
154136
assertTensorAlmostEqual(self, reg, expected_reg)
155137

@@ -200,19 +182,11 @@ def test_gstg_2d_input_with_mask(self) -> None:
200182

201183
gated_input, reg = gstg(input_tensor)
202184
expected_reg = 2.5213
185+
expected_gated_input = [
186+
[[0.0000, 0.0198], [0.0396, 0.0594], [0.2435, 0.3708]],
187+
[[0.3696, 0.5954], [0.6805, 0.7655], [0.6159, 0.3921]],
188+
]
203189

204-
if self.testing_device == "cpu":
205-
expected_gated_input = [
206-
[[0.0000, 0.0198], [0.0396, 0.0594], [0.2435, 0.3708]],
207-
[[0.3696, 0.5954], [0.6805, 0.7655], [0.6159, 0.3921]],
208-
]
209-
elif self.testing_device == "cuda":
210-
expected_gated_input = [
211-
[[0.0000, 0.0788], [0.1577, 0.2365], [0.0000, 0.1174]],
212-
[[0.0269, 0.0000], [0.0000, 0.0000], [0.0448, 0.4145]],
213-
]
214-
215-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
216190
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
217191
assertTensorAlmostEqual(self, reg, expected_reg)
218192

tests/module/test_gaussian_stochastic_gates_cuda.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,46 @@
33

44
# pyre-strict
55

6+
import unittest
7+
from unittest.mock import patch
8+
9+
import torch
10+
from captum.module.gaussian_stochastic_gates import GaussianStochasticGates
11+
from torch import Tensor
12+
613
from .test_gaussian_stochastic_gates import TestGaussianStochasticGates
714

815

16+
# CUDA RNG produces different sequences on different GPU architectures
17+
# (e.g. V100 vs A100 vs H100) even with the same seed, causing flaky tests.
18+
# By generating noise on CPU (where torch.manual_seed is deterministic across
19+
# all hardware) and moving to the device, tests get consistent results
20+
# regardless of which GPU type runs them in CI.
21+
def _cpu_rng_sample(
22+
self: GaussianStochasticGates, batch_size: int
23+
) -> Tensor:
24+
if self.training:
25+
n = torch.empty(batch_size, self.n_gates)
26+
n.normal_(mean=0, std=self.std)
27+
return self.mu + n.to(self.mu.device)
28+
return self.mu.expand(batch_size, self.n_gates)
29+
30+
931
class TestGaussianStochasticGatesCUDA(TestGaussianStochasticGates):
1032
testing_device: str = "cuda"
33+
34+
def setUp(self) -> None:
35+
super().setUp()
36+
if not torch.cuda.is_available():
37+
raise unittest.SkipTest(
38+
"Skipping GPU test since CUDA not available."
39+
)
40+
# pyre-fixme[8]: Attribute has type
41+
# `BoundMethod[..., Tensor]`; used as `(...) -> Tensor`.
42+
patcher = patch.object(
43+
GaussianStochasticGates,
44+
"_sample_gate_values",
45+
_cpu_rng_sample,
46+
)
47+
patcher.start()
48+
self.addCleanup(patcher.stop)

0 commit comments

Comments
 (0)