Skip to content

Commit

Permalink
Further improvements to existing rules (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
feluelle authored Jun 29, 2022
1 parent 5f66dc4 commit 5192030
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 114 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ To install it from [PyPI](https://pypi.org/) run:
pip install airflint
```

> **_NOTE:_** It is recommended to install airflint into your existing Airflow environment, with all your providers included. This way the `UseJinjaVariableGet` rule can detect all `template_fields` and airflint works as expected.
Then just call it like this:

![usage](assets/images/usage.png)
Expand All @@ -44,6 +46,9 @@ Alternatively you can add the following repo to your `pre-commit-config.yaml`:
hooks:
- id: airflint
args: ["-a"] # Use -a for replacing inplace
additional_dependencies: # Add all package dependencies you use in your DAGs, preferably with a version spec
- apache-airflow
- apache-airflow-providers-cncf-kubernetes
```
To complete the `UseFunctionlevelImports` rule, please add the `autoflake` hook after the `airflint` hook, as below:
Expand Down
21 changes: 0 additions & 21 deletions airflint/representatives/import_finder.py

This file was deleted.

2 changes: 1 addition & 1 deletion airflint/rules/use_function_level_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def match(self, node: ast.AST) -> Action:
imports = [
definition
for definition in definitions
if isinstance(definition, ast.Import)
if isinstance(definition, (ast.Import, ast.ImportFrom))
]
# And we'll ensure this import is originating from the global scope.
if imports and scope.scope_type is ScopeType.GLOBAL:
Expand Down
144 changes: 114 additions & 30 deletions airflint/rules/use_jinja_variable_get.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,145 @@
import ast
import importlib
from importlib import import_module
from typing import Any

from refactor import ReplacementAction, Rule
from refactor.context import Ancestry, Scope

from airflint.representatives.import_finder import ImportFinder


class UseJinjaVariableGet(Rule):
"""Replace `Variable.get("foo")` Calls through the jinja equivalent `{{ var.value.foo }}` if the variable is listed in `template_fields`."""

context_providers = (Scope, Ancestry, ImportFinder)
context_providers = (Scope, Ancestry)

def match(self, node):
assert (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "Variable"
and node.func.attr == "get"
and isinstance(node.func.ctx, ast.Load)
and isinstance(variable := node.args[0], ast.Constant)
)
assert (
(parents := self.context["ancestry"].get_parents(node))
and isinstance(operator_keyword := next(parents), ast.keyword)
and isinstance(operator_call := next(parents), ast.Call)
and isinstance(operator_call.func, ast.Name)
and (
import_node := self.context["import_finder"].collect(
operator_call.func.id,
scope=self.context["scope"].resolve(operator_call.func),
def _get_operator_keywords(self, reference: ast.Call) -> list[ast.keyword]:
    """Collect the keyword arguments that receive the value of `reference`.

    Two shapes are supported, decided by the direct parent of `reference`:

    * the call is itself a keyword value (`op(arg=Variable.get(...))`), in
      which case that single keyword is returned;
    * the call is the right-hand side of an assignment, in which case every
      keyword in the tree whose value is a name equal to one of the plain-name
      assignment targets, and whose scope can reach that target's scope, is
      returned.

    Raises:
        AssertionError: if the parent is neither a keyword nor an assignment,
            or if the assigned name is never passed as a keyword value.
    """
    parent = self.context["ancestry"].get_parent(reference)

    # Direct reference without variable assignment.
    if isinstance(parent, ast.keyword):
        return [parent]

    if not isinstance(parent, ast.Assign):
        raise AssertionError(f"Unsupported parent type {type(parent)}. Skipping..")

    # Only plain names can be looked up again by identifier; other target
    # kinds (tuples, attributes, subscripts) are ignored.
    scope = self.context["scope"]
    name_targets = [
        target for target in parent.targets if isinstance(target, ast.Name)
    ]

    found: list[ast.keyword] = []
    for candidate in ast.walk(self.context.tree):
        if not isinstance(candidate, ast.keyword):
            continue
        if not isinstance(candidate.value, ast.Name):
            continue
        for target in name_targets:
            if candidate.value.id != target.id:
                continue
            if scope.resolve(candidate.value).can_reach(scope.resolve(target)):
                found.append(candidate)
                break

    if not found:
        raise AssertionError("No operator keywords found. Skipping..")
    return found

def _lookup_template_fields(self, keyword: ast.keyword) -> None:
# Check that `keyword.arg` is listed in a `template_fields` tuple declared in
# the module the called operator was imported from.
#
# Nothing is returned on success. Every failed step is reported by raising
# AssertionError, either through an `assert` statement or an explicit raise.
parent = self.context["ancestry"].get_parent(keyword)

# Find the import node whose imported name matches the operator call's name.
# The keyword must belong to a call of a plain name, e.g. `Operator(arg=...)`.
assert isinstance(operator_call := parent, ast.Call)
assert isinstance(operator_call.func, ast.Name)
scope = self.context["scope"].resolve(operator_call.func)
try:
# Take the first `from ... import ...` in the tree that imports the called
# name and is reachable from the scope of the call. The comparison uses
# `alias.name`, i.e. the imported name rather than an `as` alias.
import_node = next(
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.ImportFrom)
and any(alias.name == operator_call.func.id for alias in node.names)
and scope.can_reach(self.context["scope"].resolve(node))
)
and import_node.module
)
except StopIteration:
raise AssertionError("Could not find import definition. Skipping..")
# `module` is falsy for imports without a module name (e.g. `from . import x`).
assert (module_name := import_node.module)

# Try to import the module into python so its source file can be located.
try:
_module = importlib.import_module(import_node.module)
_module = import_module(module_name)
except ImportError:
return
with open(_module.__file__) as file:
raise AssertionError("Could not import module. Skipping..")
# `__file__` has to be set to something truthy to be able to read the source.
assert (file_path := _module.__file__)

# Parse the ast to check if the keyword is in template_fields.
with open(file_path) as file:
module = ast.parse(file.read())
# Only an annotated assignment (`template_fields: ... = (...)`) with a tuple
# value, placed directly in the body of a class defined at module level, is
# recognised. Any such class in the module satisfies the check.
assert any(
isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == "template_fields"
and isinstance(stmt.value, ast.Tuple)
and any(
isinstance(elt, ast.Constant) and elt.value == operator_keyword.arg
isinstance(elt, ast.Constant) and elt.value == keyword.arg
for elt in stmt.value.elts
)
for module_stmt in module.body
if isinstance(module_stmt, ast.ClassDef)
for stmt in module_stmt.body
)

def _get_parameter(
self,
node: ast.Call,
position: int,
name: str,
) -> Any:
if position < len(node.args) and isinstance(
arg := node.args[position],
ast.Constant,
):
return arg.value
return next(
keyword.value.value
for keyword in node.keywords
if keyword.arg == name and isinstance(keyword.value, ast.Constant)
)

def _construct_value(self, node: ast.Call) -> str:
# Read key from Variable.get node.
key = self._get_parameter(node, position=0, name="key")

# Read optional deserialize_json from Variable.get node.
try:
deserialize_json = self._get_parameter(
node,
position=2,
name="deserialize_json",
)
var_type = "json" if deserialize_json else "value"
except StopIteration:
var_type = "value"

# Read optional default_var from Variable.get node and construct the final value.
try:
default_var = self._get_parameter(node, position=1, name="default_var")
if isinstance(default_var, str):
value = f"{{{{ var.{var_type}.get('{key}', '{default_var}') }}}}"
else:
value = f"{{{{ var.{var_type}.get('{key}', {default_var}) }}}}"
except StopIteration:
value = f"{{{{ var.{var_type}.{key} }}}}"

return value

def match(self, node):
# Match calls of the form `Variable.get(...)` and replace them with a string
# constant holding the equivalent Jinja expression.
#
# Only a call on the plain name `Variable` with attribute `get` is accepted;
# anything else fails the assert below.
assert (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "Variable"
and node.func.attr == "get"
and isinstance(node.func.ctx, ast.Load)
)

# Every operator keyword that receives the value is checked; the helper
# raises AssertionError for a keyword it cannot confirm as a template field.
for operator_keyword in self._get_operator_keywords(reference=node):
self._lookup_template_fields(keyword=operator_keyword)

# Replace the whole call node with the constant template string.
return ReplacementAction(
node,
target=ast.Constant(value=f"{{{{ var.value.{variable.value} }}}}"),
target=ast.Constant(value=self._construct_value(node)),
)
Loading

0 comments on commit 5192030

Please sign in to comment.