Skip to content
Next Next commit
Use ast.unparse for decorator generation
  • Loading branch information
ShaharNaveh committed Sep 9, 2025
commit dd5661cdb28c3a12c6e02866bbb964191cd3ed5b
40 changes: 25 additions & 15 deletions scripts/lib_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class UtMethod(enum.StrEnum):
def _generate_next_value_(name, start, count, last_values) -> str:
return name[0].lower() + name[1:]

def has_args(self) -> bool:
return self != self.ExpectedFailure

def has_cond(self) -> bool:
return self.endswith(("If", "Unless"))

Expand Down Expand Up @@ -81,17 +84,23 @@ class PatchSpec(typing.NamedTuple):
cond: str | None = None
reason: str = ""

def fmt(self) -> str:
prefix = f"@unittest.{self.ut_method}"
match self.ut_method:
case UtMethod.ExpectedFailure:
line = f"{prefix} # {COMMENT}; {self.reason}"
case UtMethod.ExpectedFailureIfWindows | UtMethod.Skip:
line = f'{prefix}("{COMMENT}; {self.reason}")'
case UtMethod.SkipIf | UtMethod.SkipUnless | UtMethod.ExpectedFailureIf:
line = f'{prefix}({self.cond}, "{COMMENT}; {self.reason}")'
def as_decorator(self) -> str:
reason = f"{COMMENT}; {self.reason}".strip(" ;")
if not self.ut_method.has_args():
return f"@unittest.{self.ut_method} # {reason}"

args = []
if self.cond:
args.append(ast.parse(self.cond).body[0].value)
args.append(ast.Constant(value=reason))

return line.strip().rstrip(";").strip()
call_node = ast.Call(
func=ast.Attribute(value=ast.Name(id="unittest"), attr=self.ut_method),
args=args,
keywords=[],
)
unparsed = ast.unparse(call_node)
return f"@{unparsed}"


class PatchEntry(typing.NamedTuple):
Expand Down Expand Up @@ -170,9 +179,7 @@ def iter_patch_entires(
if ut_method.has_cond():
cond = ast.unparse(dec_node.args[0])

reason = (
reason.replace(COMMENT, "").strip().lstrip(";").lstrip(":").strip()
)
reason = reason.removeprefix(COMMENT).strip(";:, ")
spec = PatchSpec(ut_method, cond, reason)
yield cls(parent_class, fn_node.name, spec)

Expand Down Expand Up @@ -224,7 +231,10 @@ def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int,
default=fn_node.lineno,
)
indent = " " * fn_node.col_offset
yield (lineno - 1, "\n".join(f"{indent}{spec.fmt()}" for spec in specs))
yield (
lineno - 1,
"\n".join(f"{indent}{spec.as_decorator()}" for spec in specs),
)

# Phase 2: Iterate and mark inhereted tests
for cls_name, tests in patches.items():
Expand All @@ -233,7 +243,7 @@ def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int,
print(f"WARNING: {cls_name} does not exist in remote file", file=sys.stderr)
continue
for test_name, specs in tests.items():
patch_lines = "\n".join(f"{INDENT1}{spec.fmt()}" for spec in specs)
patch_lines = "\n".join(f"{INDENT1}{spec.as_decorator()}" for spec in specs)
yield (
lineno,
f"""
Expand Down