Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 70 additions & 56 deletions scripts/fix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
It adds @unittest.expectedFailure to the test functions that are failing in RustPython, but not in CPython.
As well as marking the test with a TODO comment.

How to use:
Quick Import (recommended):
python ./scripts/fix_test.py --quick-import cpython/Lib/test/test_foo.py

This will:
1. Copy cpython/Lib/test/test_foo.py to Lib/test/test_foo.py (if not exists)
2. Run the test with RustPython
3. Mark failing tests with @unittest.expectedFailure

Manual workflow:
1. Copy a specific test from the CPython repository to the RustPython repository.
2. Remove all unexpected failures from the test and skip the tests that hang.
3. Build RustPython: cargo build --release
Expand All @@ -15,16 +23,23 @@
"""

import argparse
import ast
import itertools
import platform
import shutil
import sys
from pathlib import Path

from lib_updater import apply_patches, PatchSpec, UtMethod


def parse_args():
parser = argparse.ArgumentParser(description="Fix test.")
parser.add_argument("--path", type=Path, help="Path to test file")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--path", type=Path, help="Path to test file")
group.add_argument(
"--quick-import",
type=Path,
metavar="PATH",
help="Import from path containing /Lib/ (e.g., cpython/Lib/test/foo.py)",
)
parser.add_argument("--force", action="store_true", help="Force modification")
parser.add_argument(
"--platform", action="store_true", help="Platform specific failure"
Expand Down Expand Up @@ -102,39 +117,16 @@ def path_to_test(path) -> list[str]:
return parts[-2:] # Get class name and method name


def find_test_lineno(file: str, test: list[str]) -> tuple[int, int] | None:
"""Find the line number and column offset of a test function.
Returns (lineno, col_offset) or None if not found.
"""
a = ast.parse(file)
for key, node in ast.iter_fields(a):
if key == "body":
for n in node:
match n:
case ast.ClassDef():
if len(test) == 2 and test[0] == n.name:
for fn in n.body:
match fn:
case ast.FunctionDef() | ast.AsyncFunctionDef():
if fn.name == test[-1]:
return (fn.lineno, fn.col_offset)
case ast.FunctionDef() | ast.AsyncFunctionDef():
if n.name == test[0] and len(test) == 1:
return (n.lineno, n.col_offset)
return None


def apply_modifications(file: str, modifications: list[tuple[int, int]]) -> str:
"""Apply all modifications in reverse order to avoid line number offset issues."""
lines = file.splitlines()
fixture = "@unittest.expectedFailure"
# Sort by line number in descending order
modifications.sort(key=lambda x: x[0], reverse=True)
for lineno, col_offset in modifications:
indent = " " * col_offset
lines.insert(lineno - 1, indent + fixture)
lines.insert(lineno - 1, indent + "# TODO: RUSTPYTHON")
return "\n".join(lines)
def build_patches(test_parts_set: set[tuple[str, str]]) -> dict:
"""Convert failing tests to lib_updater patch format."""
patches = {}
for class_name, method_name in test_parts_set:
if class_name not in patches:
patches[class_name] = {}
patches[class_name][method_name] = [
PatchSpec(UtMethod.ExpectedFailure, None, "")
]
return patches


def run_test(test_name):
Expand All @@ -146,7 +138,7 @@ def run_test(test_name):
import subprocess

result = subprocess.run(
[rustpython_location, "-m", "test", "-v", test_name],
[rustpython_location, "-m", "test", "-v", "-u", "all", "--slowest", test_name],
capture_output=True,
text=True,
)
Expand All @@ -155,6 +147,33 @@ def run_test(test_name):

if __name__ == "__main__":
args = parse_args()

# Handle --quick-import: extract Lib/... path and copy if needed
if args.quick_import is not None:
src_str = str(args.quick_import)
lib_marker = "/Lib/"

if lib_marker not in src_str:
print(f"Error: --quick-import path must contain '/Lib/' (got: {src_str})")
sys.exit(1)

idx = src_str.index(lib_marker)
lib_path = Path(src_str[idx + 1 :]) # Lib/test/foo.py
src_path = args.quick_import

if not src_path.exists():
print(f"Error: Source file not found: {src_path}")
sys.exit(1)

if not lib_path.exists():
print(f"Copying: {src_path} -> {lib_path}")
lib_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(src_path, lib_path)
else:
print(f"File already exists: {lib_path}")

args.path = lib_path

test_path = args.path.resolve()
if not test_path.exists():
print(f"Error: File not found: {test_path}")
Expand All @@ -167,26 +186,21 @@ def run_test(test_name):
tests = run_test(test_name)
f = test_path.read_text(encoding="utf-8")

# Collect all modifications first (with deduplication for subtests)
modifications = []
# Collect failing tests (with deduplication for subtests)
seen_tests = set() # Track (class_name, method_name) to avoid duplicates
for test in tests.tests:
if test.result == "fail" or test.result == "error":
test_parts = path_to_test(test.path)
test_key = tuple(test_parts)
if test_key in seen_tests:
continue # Skip duplicate (same test, different subtest)
seen_tests.add(test_key)
location = find_test_lineno(f, test_parts)
if location:
print(f"Modifying test: {test.name} at line {location[0]}")
modifications.append(location)
else:
print(f"Warning: Could not find test: {test.name} ({test_parts})")

# Apply all modifications in reverse order
if modifications:
f = apply_modifications(f, modifications)
if len(test_parts) == 2:
test_key = tuple(test_parts)
if test_key not in seen_tests:
seen_tests.add(test_key)
print(f"Marking test: {test_parts[0]}.{test_parts[1]}")

# Apply patches using lib_updater
if seen_tests:
patches = build_patches(seen_tests)
f = apply_patches(f, patches)
test_path.write_text(f, encoding="utf-8")

print(f"Modified {len(modifications)} tests")
print(f"Modified {len(seen_tests)} tests")