Skip to content
This repository was archived by the owner on Mar 2, 2022. It is now read-only.

Commit 986522d

Browse files
committed
Start of AST based optimization pass - support unary ops
Summary: This adds the first new AST based optimization - folding unary ops into a new constant. This also introduces the usage of the ast.Constant class instead of using dedicated classes (e.g. ast.Num, ast.Str). All constants end up being of this new type. So we also get support for emitting a LOAD_CONST for nodes of these types. Test Plan: ./python -m test.test_compiler
1 parent 90e3449 commit 986522d

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

compiler/optimizer.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import ast
2+
import operator
3+
from ast import Constant, Num, Str, Bytes, Ellipsis, NameConstant, copy_location
4+
5+
from compiler.visitor import ASTRewriter
6+
7+
8+
def is_const(node):
9+
return isinstance(node, (Constant, Num, Str, Bytes, Ellipsis, NameConstant))
10+
11+
12+
def get_const_value(node):
13+
if isinstance(node, (Constant, NameConstant)):
14+
return node.value
15+
elif isinstance(node, Num):
16+
return node.n
17+
elif isinstance(node, (Str, Bytes)):
18+
return node.s
19+
elif isinstance(node, Ellipsis):
20+
return ...
21+
22+
raise TypeError("Bad constant value")
23+
24+
25+
UNARY_OPS = {
26+
ast.Invert: operator.invert,
27+
ast.Not: operator.not_,
28+
ast.UAdd: operator.pos,
29+
ast.USub: operator.neg,
30+
}
31+
INVERSE_OPS = {
32+
ast.Is: ast.IsNot,
33+
ast.IsNot: ast.Is,
34+
ast.In: ast.NotIn,
35+
ast.NotIn: ast.In,
36+
}
37+
38+
39+
class AstOptimizer(ASTRewriter):
40+
def visitUnaryOp(self, node: ast.UnaryOp):
41+
op = self.visit(node.operand)
42+
if is_const(op):
43+
conv = UNARY_OPS[type(node.op)]
44+
val = get_const_value(op)
45+
try:
46+
return copy_location(Constant(conv(val)), node)
47+
except:
48+
pass
49+
elif (
50+
isinstance(node.op, ast.Not)
51+
and isinstance(node.operand, ast.Compare)
52+
and len(node.operand.ops) == 1
53+
):
54+
cmp_op = node.operand.ops[0]
55+
new_op = INVERSE_OPS.get(type(cmp_op))
56+
if new_op is not None:
57+
return self.update_node(node.operand, ops=[new_op()])
58+
59+
return self.update_node(node, operand=op)

compiler/pycodegen.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def get_bool_const(node):
9090
if isinstance(node, ast.Name):
9191
if node.id == "__debug__":
9292
return not OPTIMIZE
93+
if isinstance(node, ast.Constant):
94+
return bool(node.value)
9395

9496

9597
def is_constant_false(node):
@@ -2059,6 +2061,10 @@ def compileJumpIf(self, test, next, is_if_true):
20592061
self.emit('POP_JUMP_IF_TRUE' if is_if_true else 'POP_JUMP_IF_FALSE', next)
20602062
return True
20612063

2064+
def visitConstant(self, node: ast.Constant):
2065+
self.update_lineno(node)
2066+
self.emit('LOAD_CONST', node.value)
2067+
20622068

20632069
def get_default_generator():
20642070
if sys.version_info >= (3, 7):

test_compiler/test_py37.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import ast
12
import dis
23
from .common import CompilerTest
34
from compiler.pycodegen import CodeGenerator, Python37CodeGenerator
5+
from compiler.optimizer import AstOptimizer
6+
from compiler.unparse import to_expr
47
from compiler.consts import (
58
CO_OPTIMIZED,
69
CO_NOFREE,
@@ -143,3 +146,21 @@ def test_compile_opt_chained_cmp_op(self):
143146

144147
graph = self.to_graph('assert not a > b > c', CodeGenerator)
145148
self.assertInGraph(graph, 'UNARY_NOT')
149+
150+
def test_ast_optimizer(self):
151+
cases = [
152+
("+1", "1"),
153+
("--1", "1"),
154+
("~1", "-2"),
155+
("not 1", "False"),
156+
("not x is y", "x is not y"),
157+
("not x is not y", "x is y"),
158+
("not x in y", "x not in y"),
159+
("~1.1", "~1.1"),
160+
("+'str'", "+'str'"),
161+
]
162+
for inp, expected in cases:
163+
optimizer = AstOptimizer()
164+
tree = ast.parse(inp)
165+
optimized = to_expr(optimizer.visit(tree).body[0].value)
166+
self.assertEqual(expected, optimized, "Input was: " + inp)

0 commit comments

Comments
 (0)