Skip to content
This repository was archived by the owner on Mar 2, 2022. It is now read-only.

Commit 90e3449

Browse files
committed
Adds ast rewriter
Summary: Python 3.7 starts doing various optimizations at the AST level, where it actually makes sense. To do this we need the ability to re-write the ASTs. This adds an ASTRewriter class which supports having the visit() methods return the new nodes to be re-written. Child nodes which aren't modified remain in the tree, and the spine of the AST is rewritten as we walk back up the AST. Even though the Python ASTs are mutable this supports performing the re-writes w/o modifying the AST in place. Test Plan: ./python -m test.test_compiler
1 parent 836df8c commit 90e3449

File tree

4 files changed

+160
-0
lines changed

4 files changed

+160
-0
lines changed

compiler/visitor.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import print_function
22

33
import ast
4+
from ast import AST, copy_location
45
from typing import Any, List
56

67
# XXX should probably rename ASTVisitor to ASTWalker
@@ -63,6 +64,75 @@ def visit(self, node, *args):
6364
return meth(node, *args)
6465

6566

67+
class ASTRewriter(ASTVisitor):
68+
"""performs rewrites on the AST, rewriting parent nodes when child nodes
69+
are replaced."""
70+
71+
@staticmethod
72+
def update_node(node: AST, **replacement: Any) -> AST:
73+
res = node
74+
for name, val in replacement.items():
75+
existing = getattr(res, name)
76+
if existing is val:
77+
continue
78+
79+
if node is res:
80+
res = ASTRewriter.clone_node(node)
81+
82+
setattr(res, name, val)
83+
return res
84+
85+
@staticmethod
86+
def clone_node(node: AST) -> AST:
87+
attrs = []
88+
for name in node._fields:
89+
attr = getattr(node, name, None)
90+
if isinstance(attr, list):
91+
attr = list(attr)
92+
attrs.append(attr)
93+
94+
new = type(node)(*attrs)
95+
return copy_location(new, node)
96+
97+
def walk_list(self, old_values: List[AST]) -> List[AST]:
98+
new_values = []
99+
changed = False
100+
for value in old_values:
101+
if isinstance(value, AST):
102+
new_value = self.visit(value)
103+
changed |= new_value is not value
104+
if new_value is None:
105+
continue
106+
elif not isinstance(new_value, AST):
107+
new_values.extend(new_value)
108+
continue
109+
value = new_value
110+
111+
new_values.append(value)
112+
return new_values if changed else old_values
113+
114+
def generic_visit(self, node: AST, *args) -> AST:
115+
if isinstance(node, list):
116+
return self.walk_list(node)
117+
118+
ret_node = node
119+
for field, old_value in ast.iter_fields(node):
120+
if not isinstance(old_value, (AST, list)):
121+
continue
122+
123+
new_node = self.visit(old_value)
124+
assert ( # noqa: IG01
125+
new_node is not None
126+
), "can't remove AST nodes that aren't part of a list"
127+
if new_node is not old_value:
128+
if ret_node is node:
129+
ret_node = self.clone_node(node)
130+
131+
setattr(ret_node, field, new_node)
132+
133+
return ret_node
134+
135+
66136
class ExampleASTVisitor(ASTVisitor):
67137
"""Prints examples of the nodes that aren't visited
68138

test_compiler/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .test_sbs_stdlib import SbsCompileTests
99
from .test_symbols import SymbolVisitorTests
1010
from .test_unparse import UnparseTests
11+
from .test_visitor import VisitorTests

test_compiler/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .test_errors import ErrorTests, ErrorTestsBuiltin
1010
from .test_symbols import SymbolVisitorTests
1111
from .test_unparse import UnparseTests
12+
from .test_visitor import VisitorTests
1213

1314
if __name__ == "__main__":
1415
unittest.main()

test_compiler/test_visitor.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import ast
2+
from compiler.visitor import ASTRewriter
3+
from compiler.unparse import to_expr
4+
from unittest import TestCase
5+
6+
7+
class VisitorTests(TestCase):
8+
def test_rewrite_node(self):
9+
class TestRewriter(ASTRewriter):
10+
def visitName(self, node):
11+
if isinstance(node, ast.Name) and node.id == "z":
12+
return ast.Name("foo", ast.Load())
13+
return node
14+
15+
tree = ast.parse("x + y + z").body[0].value
16+
17+
rewriter = TestRewriter()
18+
new_tree = rewriter.visit(tree)
19+
20+
self.assertIsNotNone(new_tree)
21+
self.assertNotEqual(new_tree, tree, "tree should be rewritten")
22+
self.assertEqual(to_expr(new_tree), "x + y + foo")
23+
self.assertEqual(tree.left, new_tree.left, "Unchanged nodes should be the same")
24+
25+
def test_rewrite_stmt(self):
26+
class TestRewriter(ASTRewriter):
27+
def visitAssign(self, node):
28+
return ast.AnnAssign(
29+
ast.Name("foo", ast.Store), ast.Str("foo"), ast.Num(1), True
30+
)
31+
32+
tree = ast.parse("x = 1\nf()\n")
33+
34+
rewriter = TestRewriter()
35+
new_tree = rewriter.visit(tree)
36+
37+
self.assertIsNotNone(new_tree)
38+
self.assertNotEqual(new_tree, tree, "tree should be rewritten")
39+
self.assertNotEqual(tree.body[0], new_tree.body[0])
40+
self.assertEqual(tree.body[1], new_tree.body[1])
41+
42+
def test_remove_node(self):
43+
class TestRewriter(ASTRewriter):
44+
def visitAssign(self, node):
45+
return None
46+
47+
tree = ast.parse("x = 1\nf()\n")
48+
49+
rewriter = TestRewriter()
50+
new_tree = rewriter.visit(tree)
51+
52+
self.assertIsNotNone(new_tree)
53+
self.assertEqual(tree.body[1], new_tree.body[0])
54+
self.assertEqual(len(new_tree.body), 1)
55+
56+
def test_change_child_and_list(self):
57+
class TestRewriter(ASTRewriter):
58+
def visitarguments(self, node: ast.arguments):
59+
node = self.clone_node(node)
60+
node.vararg = ast.arg("args", None)
61+
return node
62+
63+
def visitAssign(self, node):
64+
return ast.Pass()
65+
66+
tree = ast.parse("def f():\n x = 1")
67+
68+
rewriter = TestRewriter()
69+
new_tree = rewriter.visit(tree)
70+
func = new_tree.body[0]
71+
self.assertEqual(type(func.body[0]), ast.Pass)
72+
self.assertIsNotNone(func.args.vararg)
73+
74+
def test_change_list_and_child(self):
75+
class TestRewriter(ASTRewriter):
76+
def visitStr(self, node: ast.Str):
77+
return ast.Str("bar")
78+
79+
def visitAssign(self, node):
80+
return ast.Pass()
81+
82+
tree = ast.parse("def f() -> 'foo':\n x = 1")
83+
84+
rewriter = TestRewriter()
85+
new_tree = rewriter.visit(tree)
86+
func = new_tree.body[0]
87+
self.assertIsNotNone(func.returns)
88+
self.assertEqual(type(func.body[0]), ast.Pass)

0 commit comments

Comments
 (0)