Skip to content

Commit 1682b1a

Browse files
committed
gh-138122: Make sampling profiler integration tests more resilient
The tests were flaky on slow machines because subprocesses could finish before enough samples were collected. This adds synchronization similar to test_external_inspection: test scripts now signal when they start working, and the profiler waits for this signal before sampling. Test scripts now run in infinite loops until killed rather than for fixed iterations, ensuring the profiler always has active work to sample regardless of machine speed.
1 parent d6d850d commit 1682b1a

File tree

2 files changed

+150
-57
lines changed

2 files changed

+150
-57
lines changed

Lib/test/test_profiling/test_sampling_profiler/helpers.py

Lines changed: 86 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -38,12 +38,87 @@
3838
SubprocessInfo = namedtuple("SubprocessInfo", ["process", "socket"])
3939

4040

41+
def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT):
    """
    Wait for expected signal(s) from a socket with proper timeout and EOF handling.

    Args:
        sock: Connected socket to read from
        expected_signals: Single bytes/bytearray object, or an iterable of
            them; every one must appear somewhere in the received data
        timeout: Overall time budget in seconds for all signals to arrive
            (None blocks indefinitely, 0 performs non-blocking reads).
            On successful return the socket is left with this timeout set.

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If the connection is closed before the signals are
            received, if the time budget runs out, or on any other socket
            error (in which case the OSError is chained as the cause)
    """
    # Imported locally so this helper does not rely on a file-level import.
    import time

    if isinstance(expected_signals, (bytes, bytearray)):
        expected_signals = [expected_signals]
    else:
        # Materialize once: the membership check below runs on every loop
        # iteration and would exhaust a one-shot iterable.
        expected_signals = list(expected_signals)

    # Validates the timeout value up front (negative values raise here).
    sock.settimeout(timeout)
    deadline = None if timeout is None else time.monotonic() + timeout
    buffer = b""

    def failure(reason):
        # All error messages share the same diagnostic tail.
        return RuntimeError(
            f"{reason} "
            f"Expected: {expected_signals}, Got: {buffer[-200:]!r}"
        )

    while True:
        # Check if all expected signals are in buffer
        if all(sig in buffer for sig in expected_signals):
            # Leave the socket configured with the caller's timeout rather
            # than with whatever was left of the budget.
            sock.settimeout(timeout)
            return buffer

        if deadline is not None:
            # Shrink the per-recv timeout so that the *total* wait never
            # exceeds the budget, even if the peer keeps sending data that
            # does not contain the expected signals.
            remaining = deadline - time.monotonic()
            if timeout > 0 and remaining <= 0:
                raise failure("Timeout waiting for signals.")
            sock.settimeout(max(remaining, 0.0))

        try:
            chunk = sock.recv(4096)
        except socket.timeout:
            raise failure("Timeout waiting for signals.") from None
        except OSError as e:
            # Keep the original exception as the cause for easier debugging.
            raise failure(
                f"Socket error while waiting for signals: {e}."
            ) from e

        if not chunk:
            # An empty read means the peer closed the connection (EOF).
            raise failure(
                "Connection closed before receiving expected signals."
            )
        buffer += chunk
85+
86+
87+
def _cleanup_sockets(*sockets):
    """Close each given socket; ``None`` entries are skipped and OSError is ignored."""
    for candidate in sockets:
        if candidate is None:
            continue
        try:
            candidate.close()
        except OSError:
            # Best-effort cleanup: a failing close() must not prevent the
            # remaining sockets from being closed.
            continue
95+
96+
97+
def _cleanup_process(proc, timeout=SHORT_TIMEOUT):
    """Stop a process: terminate first, then kill if it does not exit in time.

    Does nothing when the process has already exited. After each stop
    request, waits up to ``timeout`` seconds for the process to exit.
    """
    if proc.poll() is not None:
        return

    # Escalation ladder: polite request first, forced kill second.
    for stop in (proc.terminate, proc.kill):
        stop()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Still alive: move on to the next step. After kill() there is
            # nothing more we can do, so the loop simply ends.
            continue
        return
112+
113+
41114
@contextlib.contextmanager
42-
def test_subprocess(script):
115+
def test_subprocess(script, wait_for_working=True):
43116
"""Context manager to create a test subprocess with socket synchronization.
44117
45118
Args:
46-
script: Python code to execute in the subprocess
119+
script: Python code to execute in the subprocess. Should send b"working"
120+
to signal when work has started.
121+
wait_for_working: If True, wait for both "ready" and "working" signals
47122
48123
Yields:
49124
SubprocessInfo: Named tuple with process and socket objects
@@ -80,19 +155,18 @@ def test_subprocess(script):
80155
# Wait for process to connect and send ready signal
81156
client_socket, _ = server_socket.accept()
82157
server_socket.close()
83-
response = client_socket.recv(1024)
84-
if response != b"ready":
85-
raise RuntimeError(
86-
f"Unexpected response from subprocess: {response!r}"
87-
)
158+
server_socket = None
159+
160+
# Wait for ready signal, and optionally working signal
161+
if wait_for_working:
162+
_wait_for_signal(client_socket, [b"ready", b"working"])
163+
else:
164+
_wait_for_signal(client_socket, b"ready")
88165

89166
yield SubprocessInfo(proc, client_socket)
90167
finally:
91-
if client_socket is not None:
92-
client_socket.close()
93-
if proc.poll() is None:
94-
proc.kill()
95-
proc.wait()
168+
_cleanup_sockets(client_socket, server_socket)
169+
_cleanup_process(proc)
96170

97171

98172
def close_and_unlink(file):

Lib/test/test_profiling/test_sampling_profiler/test_integration.py

Lines changed: 64 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,9 @@
3939
# Duration for profiling tests - long enough for process to complete naturally
4040
PROFILING_TIMEOUT = str(int(SHORT_TIMEOUT))
4141

42+
# Duration for profiling in tests - short enough to complete quickly
43+
PROFILING_DURATION_SEC = 2
44+
4245

4346
@skip_if_not_supported
4447
@unittest.skipIf(
@@ -359,42 +362,49 @@ def total_occurrences(func):
359362
self.assertEqual(total_occurrences(main_key), 2)
360363

361364

362-
@requires_subprocess()
363-
@skip_if_not_supported
364-
class TestSampleProfilerIntegration(unittest.TestCase):
365-
@classmethod
366-
def setUpClass(cls):
367-
cls.test_script = '''
368-
import time
369-
import os
370-
365+
# Shared workload functions for test scripts
366+
_WORKLOAD_FUNCTIONS = '''
371367
def slow_fibonacci(n):
372-
"""Recursive fibonacci - should show up prominently in profiler."""
373368
if n <= 1:
374369
return n
375370
return slow_fibonacci(n-1) + slow_fibonacci(n-2)
376371
377372
def cpu_intensive_work():
378-
"""CPU intensive work that should show in profiler."""
379373
result = 0
380374
for i in range(10000):
381375
result += i * i
382376
if i % 100 == 0:
383377
result = result % 1000000
384378
return result
385379
386-
def main_loop():
387-
"""Main test loop."""
388-
max_iterations = 200
389-
390-
for iteration in range(max_iterations):
380+
def do_work():
381+
iteration = 0
382+
while True:
391383
if iteration % 2 == 0:
392-
result = slow_fibonacci(15)
384+
slow_fibonacci(15)
393385
else:
394-
result = cpu_intensive_work()
386+
cpu_intensive_work()
387+
iteration += 1
388+
'''
389+
395390

396-
if __name__ == "__main__":
397-
main_loop()
391+
@requires_subprocess()
392+
@skip_if_not_supported
393+
class TestSampleProfilerIntegration(unittest.TestCase):
394+
@classmethod
395+
def setUpClass(cls):
396+
# Test script for use with test_subprocess() - signals when work starts
397+
cls.test_script = _WORKLOAD_FUNCTIONS + '''
398+
_test_sock.sendall(b"working")
399+
do_work()
400+
'''
401+
# CLI test script - runs for fixed duration (no socket sync)
402+
cls.cli_test_script = '''
403+
import time
404+
''' + _WORKLOAD_FUNCTIONS.replace(
405+
'while True:', 'end_time = time.time() + 30\n while time.time() < end_time:'
406+
) + '''
407+
do_work()
398408
'''
399409

400410
def test_sampling_basic_functionality(self):
@@ -404,12 +414,11 @@ def test_sampling_basic_functionality(self):
404414
mock.patch("sys.stdout", captured_output),
405415
):
406416
try:
407-
# Sample for up to SHORT_TIMEOUT seconds, but process exits after fixed iterations
408417
collector = PstatsCollector(sample_interval_usec=1000, skip_idle=False)
409418
profiling.sampling.sample.sample(
410419
subproc.process.pid,
411420
collector,
412-
duration_sec=SHORT_TIMEOUT,
421+
duration_sec=PROFILING_DURATION_SEC,
413422
)
414423
collector.print_stats(show_summary=False)
415424
except PermissionError:
@@ -442,7 +451,7 @@ def test_sampling_with_pstats_export(self):
442451
profiling.sampling.sample.sample(
443452
subproc.process.pid,
444453
collector,
445-
duration_sec=1,
454+
duration_sec=PROFILING_DURATION_SEC,
446455
)
447456
collector.export(pstats_out.name)
448457
except PermissionError:
@@ -488,7 +497,7 @@ def test_sampling_with_collapsed_export(self):
488497
profiling.sampling.sample.sample(
489498
subproc.process.pid,
490499
collector,
491-
duration_sec=1,
500+
duration_sec=PROFILING_DURATION_SEC,
492501
)
493502
collector.export(collapsed_file.name)
494503
except PermissionError:
@@ -536,7 +545,7 @@ def test_sampling_all_threads(self):
536545
profiling.sampling.sample.sample(
537546
subproc.process.pid,
538547
collector,
539-
duration_sec=1,
548+
duration_sec=PROFILING_DURATION_SEC,
540549
all_threads=True,
541550
)
542551
collector.print_stats(show_summary=False)
@@ -548,12 +557,16 @@ def test_sampling_all_threads(self):
548557

549558
def test_sample_target_script(self):
550559
script_file = tempfile.NamedTemporaryFile(delete=False)
551-
script_file.write(self.test_script.encode("utf-8"))
560+
script_file.write(self.cli_test_script.encode("utf-8"))
552561
script_file.flush()
553562
self.addCleanup(close_and_unlink, script_file)
554563

555-
# Sample for up to SHORT_TIMEOUT seconds, but process exits after fixed iterations
556-
test_args = ["profiling.sampling.sample", "run", "-d", PROFILING_TIMEOUT, script_file.name]
564+
# Sample for PROFILING_DURATION_SEC seconds
565+
test_args = [
566+
"profiling.sampling.sample", "run",
567+
"-d", str(PROFILING_DURATION_SEC),
568+
script_file.name
569+
]
557570

558571
with (
559572
mock.patch("sys.argv", test_args),
@@ -583,13 +596,13 @@ def test_sample_target_module(self):
583596
module_path = os.path.join(tempdir.name, "test_module.py")
584597

585598
with open(module_path, "w") as f:
586-
f.write(self.test_script)
599+
f.write(self.cli_test_script)
587600

588601
test_args = [
589602
"profiling.sampling.cli",
590603
"run",
591604
"-d",
592-
PROFILING_TIMEOUT,
605+
str(PROFILING_DURATION_SEC),
593606
"-m",
594607
"test_module",
595608
]
@@ -630,8 +643,10 @@ def test_invalid_pid(self):
630643
profiling.sampling.sample.sample(-1, collector, duration_sec=1)
631644

632645
def test_process_dies_during_sampling(self):
646+
# Use wait_for_working=False since this simple script doesn't send "working"
633647
with test_subprocess(
634-
"import time; time.sleep(0.5); exit()"
648+
"import time; time.sleep(0.5); exit()",
649+
wait_for_working=False
635650
) as subproc:
636651
with (
637652
io.StringIO() as captured_output,
@@ -654,7 +669,11 @@ def test_process_dies_during_sampling(self):
654669
self.assertIn("Error rate", output)
655670

656671
def test_is_process_running(self):
657-
with test_subprocess("import time; time.sleep(1000)") as subproc:
672+
# Use wait_for_working=False since this simple script doesn't send "working"
673+
with test_subprocess(
674+
"import time; time.sleep(1000)",
675+
wait_for_working=False
676+
) as subproc:
658677
try:
659678
profiler = SampleProfiler(
660679
pid=subproc.process.pid,
@@ -681,7 +700,11 @@ def test_is_process_running(self):
681700

682701
@unittest.skipUnless(sys.platform == "linux", "Only valid on Linux")
683702
def test_esrch_signal_handling(self):
684-
with test_subprocess("import time; time.sleep(1000)") as subproc:
703+
# Use wait_for_working=False since this simple script doesn't send "working"
704+
with test_subprocess(
705+
"import time; time.sleep(1000)",
706+
wait_for_working=False
707+
) as subproc:
685708
try:
686709
unwinder = _remote_debugging.RemoteUnwinder(
687710
subproc.process.pid
@@ -793,38 +816,34 @@ class TestAsyncAwareProfilingIntegration(unittest.TestCase):
793816

794817
@classmethod
795818
def setUpClass(cls):
819+
# Async test script that runs indefinitely until killed.
820+
# Sends "working" signal AFTER tasks are created and scheduled.
796821
cls.async_script = '''
797822
import asyncio
798823
799824
async def sleeping_leaf():
800-
"""Leaf task that just sleeps - visible in 'all' mode."""
801-
for _ in range(50):
825+
while True:
802826
await asyncio.sleep(0.02)
803827
804828
async def cpu_leaf():
805-
"""Leaf task that does CPU work - visible in both modes."""
806829
total = 0
807-
for _ in range(200):
830+
while True:
808831
for i in range(10000):
809832
total += i * i
810833
await asyncio.sleep(0)
811-
return total
812834
813835
async def supervisor():
814-
"""Middle layer that spawns leaf tasks."""
815836
tasks = [
816837
asyncio.create_task(sleeping_leaf(), name="Sleeper-0"),
817838
asyncio.create_task(sleeping_leaf(), name="Sleeper-1"),
818839
asyncio.create_task(sleeping_leaf(), name="Sleeper-2"),
819840
asyncio.create_task(cpu_leaf(), name="Worker"),
820841
]
842+
await asyncio.sleep(0) # Let tasks get scheduled
843+
_test_sock.sendall(b"working")
821844
await asyncio.gather(*tasks)
822845
823-
async def main():
824-
await supervisor()
825-
826-
if __name__ == "__main__":
827-
asyncio.run(main())
846+
asyncio.run(supervisor())
828847
'''
829848

830849
def _collect_async_samples(self, async_aware_mode):
@@ -838,7 +857,7 @@ def _collect_async_samples(self, async_aware_mode):
838857
profiling.sampling.sample.sample(
839858
subproc.process.pid,
840859
collector,
841-
duration_sec=SHORT_TIMEOUT,
860+
duration_sec=PROFILING_DURATION_SEC,
842861
async_aware=async_aware_mode,
843862
)
844863
except PermissionError:

0 commit comments

Comments
 (0)