Skip to content
Prev Previous commit
Next Next commit
Cleanup code a bit
  • Loading branch information
ShaharNaveh committed Sep 9, 2025
commit e0b79841f1fb59885f1633648c2bfbf5d9fc4834
68 changes: 35 additions & 33 deletions scripts/lib_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
INDENT1 = " " * COL_OFFSET
INDENT2 = INDENT1 * 2
COMMENT = "TODO: RUSTPYTHON"
UT = "unittest"


@enum.unique
Expand Down Expand Up @@ -87,15 +88,15 @@ class PatchSpec(typing.NamedTuple):
def as_decorator(self) -> str:
reason = f"{COMMENT}; {self.reason}".strip(" ;")
if not self.ut_method.has_args():
return f"@unittest.{self.ut_method} # {reason}"
return f"@{UT}.{self.ut_method} # {reason}"

args = []
if self.cond:
args.append(ast.parse(self.cond).body[0].value)
args.append(ast.Constant(value=reason))

call_node = ast.Call(
func=ast.Attribute(value=ast.Name(id="unittest"), attr=self.ut_method),
func=ast.Attribute(value=ast.Name(id=UT), attr=self.ut_method),
args=args,
keywords=[],
)
Expand Down Expand Up @@ -137,7 +138,7 @@ def iter_patch_entires(

if (
isinstance(attr_node, ast.Name)
or getattr(attr_node.value, "id", None) != "unittest"
or getattr(attr_node.value, "id", None) != UT
):
continue

Expand All @@ -147,37 +148,37 @@ def iter_patch_entires(
except ValueError:
continue

match ut_method:
case UtMethod.ExpectedFailure:
# Search first on decorator line, then in the line before
for line in lines[
dec_node.lineno - 1 : dec_node.lineno - 3 : -1
]:
if COMMENT not in line:
continue
reason = "".join(re.findall(rf"{COMMENT}.?(.*)", line))
# If our ut_method has args then,
# we need to search for a constant that contains our `COMMENT`.
# Otherwise we need to search it in the raw source code :/
if ut_method.has_args():
reason = next(
(
node.value
for node in ast.walk(dec_node)
if isinstance(node, ast.Constant)
and isinstance(node.value, str)
and COMMENT in node.value
),
None,
)

Comment on lines +159 to +173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Harden reason/cond extraction against edge cases.

  • Guard against missing args when has_cond() is true to avoid IndexError.
  • Make the regex more robust and anchored to end-of-line to avoid over-capturing.
  • Optionally accept f-strings (ast.JoinedStr) as reasons.

Apply this diff:

-                if ut_method.has_args():
+                if ut_method.has_args():
                     reason = next(
                         (
-                            node.value
-                            for node in ast.walk(dec_node)
-                            if isinstance(node, ast.Constant)
-                            and isinstance(node.value, str)
-                            and COMMENT in node.value
+                            node.value
+                            for node in ast.walk(dec_node)
+                            if isinstance(node, ast.Constant)
+                            and isinstance(node.value, str)
+                            and COMMENT in node.value
                         ),
                         None,
                     )
@@
-                    if ut_method.has_cond():
-                        cond = ast.unparse(dec_node.args[0])
+                    if ut_method.has_cond() and getattr(dec_node, "args", []):
+                        cond = ast.unparse(dec_node.args[0])
                 else:
                     # Search first on decorator line, then in the line before
                     for line in lines[dec_node.lineno - 1 : dec_node.lineno - 3 : -1]:
-                        if found := re.search(rf"{COMMENT}.?(.*)", line):
-                            reason = found.group()
+                        if found := re.search(rf"{re.escape(COMMENT)}\s*[:;,-]?\s*(.*)$", line):
+                            reason = found.group()
                             break
                     else:
                         # Didn't find our `COMMENT` :)
                         continue

-                reason = reason.removeprefix(COMMENT).strip(";:, ")
+                reason = reason.removeprefix(COMMENT).strip(";:, ")

If you want f-strings support, we can extend the walker to accept ast.JoinedStr and convert via ast.unparse(node).

Also applies to: 174-181, 182-191

🤖 Prompt for AI Agents
In scripts/lib_updater.py around lines 159-173 (and similarly 174-181, 182-191),
the extraction of reason/cond is brittle: it can IndexError when args are
missing, the regex may over-capture, and f-strings (ast.JoinedStr) are not
handled. Fix by first checking ut_method.has_args() and also guarding
ut_method.args list length before indexing; when has_cond() is true verify args
exist before accessing and return None if missing; tighten the regex to anchor
the match to end-of-line (e.g., use a pattern that requires COMMENT followed by
optional whitespace and end-of-line) to avoid over-capturing; and extend the AST
walker to accept ast.JoinedStr nodes (convert them via ast.unparse(node)) in
addition to ast.Constant strings so f-strings are supported; apply the same
guards and parsing changes to the other two blocks at lines 174-181 and 182-191.

# If we didn't find a constant containing <COMMENT>,
# then we didn't put this decorator
if not reason:
continue

if ut_method.has_cond():
cond = ast.unparse(dec_node.args[0])
else:
# Search first on decorator line, then in the line before
for line in lines[dec_node.lineno - 1 : dec_node.lineno - 3 : -1]:
if found := re.search(rf"{COMMENT}.?(.*)", line):
reason = found.group()
break
else:
continue
case _:
reason = next(
(
node.value
for node in ast.walk(dec_node)
if isinstance(node, ast.Constant)
and isinstance(node.value, str)
and COMMENT in node.value
),
None,
)

# If we didn't find a constant containing <COMMENT>,
# then we didn't put this decorator
if not reason:
continue

if ut_method.has_cond():
cond = ast.unparse(dec_node.args[0])
else:
# Didn't find our `COMMENT` :)
continue

reason = reason.removeprefix(COMMENT).strip(";:, ")
spec = PatchSpec(ut_method, cond, reason)
Expand Down Expand Up @@ -242,6 +243,7 @@ def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int,
if not lineno:
print(f"WARNING: {cls_name} does not exist in remote file", file=sys.stderr)
continue

for test_name, specs in tests.items():
patch_lines = "\n".join(f"{INDENT1}{spec.as_decorator()}" for spec in specs)
yield (
Expand Down