Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further improvements to existing rules #13

Merged
merged 4 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ To install it from [PyPI](https://pypi.org/) run:
pip install airflint
```

> **_NOTE:_** It is recommended to install airflint into your existing airflow environment with all your providers included. This way `UseJinjaVariableGet` rule can detect all `template_fields` and airflint works as expected.

Then just call it like this:

![usage](assets/images/usage.png)
Expand All @@ -44,6 +46,9 @@ Alternatively you can add the following repo to your `pre-commit-config.yaml`:
hooks:
- id: airflint
args: ["-a"] # Use -a for replacing inplace
additional_dependencies: # Add all package dependencies you have in your dags, preferable with version spec
- apache-airflow
- apache-airflow-providers-cncf-kubernetes
```

To complete the `UseFunctionlevelImports` rule, please add the `autoflake` hook after the `airflint` hook, as below:
Expand Down
21 changes: 0 additions & 21 deletions airflint/representatives/import_finder.py

This file was deleted.

2 changes: 1 addition & 1 deletion airflint/rules/use_function_level_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def match(self, node: ast.AST) -> Action:
imports = [
definition
for definition in definitions
if isinstance(definition, ast.Import)
if isinstance(definition, (ast.Import, ast.ImportFrom))
]
# And we'll ensure this import is originating from the global scope.
if imports and scope.scope_type is ScopeType.GLOBAL:
Expand Down
144 changes: 114 additions & 30 deletions airflint/rules/use_jinja_variable_get.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,145 @@
import ast
import importlib
from importlib import import_module
from typing import Any

from refactor import ReplacementAction, Rule
from refactor.context import Ancestry, Scope

from airflint.representatives.import_finder import ImportFinder


class UseJinjaVariableGet(Rule):
"""Replace `Variable.get("foo")` Calls through the jinja equivalent `{{ var.value.foo }}` if the variable is listed in `template_fields`."""

context_providers = (Scope, Ancestry, ImportFinder)
context_providers = (Scope, Ancestry)

def match(self, node):
assert (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "Variable"
and node.func.attr == "get"
and isinstance(node.func.ctx, ast.Load)
and isinstance(variable := node.args[0], ast.Constant)
)
assert (
(parents := self.context["ancestry"].get_parents(node))
and isinstance(operator_keyword := next(parents), ast.keyword)
and isinstance(operator_call := next(parents), ast.Call)
and isinstance(operator_call.func, ast.Name)
and (
import_node := self.context["import_finder"].collect(
operator_call.func.id,
scope=self.context["scope"].resolve(operator_call.func),
def _get_operator_keywords(self, reference: ast.Call) -> list[ast.keyword]:
parent = self.context["ancestry"].get_parent(reference)

if isinstance(parent, ast.Assign):
# Get all operator keywords referencing the variable, Variable.get call was assigned to.
scope = self.context["scope"]
operator_keywords = [
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.keyword)
and isinstance(node.value, ast.Name)
and any(
node.value.id == target.id
and scope.resolve(node.value).can_reach(scope.resolve(target))
for target in parent.targets
if isinstance(target, ast.Name)
)
]
if operator_keywords:
return operator_keywords
raise AssertionError("No operator keywords found. Skipping..")

if isinstance(parent, ast.keyword):
# Direct reference without variable assignment.
return [parent]

raise AssertionError(f"Unsupported parent type {type(parent)}. Skipping..")

def _lookup_template_fields(self, keyword: ast.keyword) -> None:
parent = self.context["ancestry"].get_parent(keyword)

# Find the import node module matching the operator calls name.
assert isinstance(operator_call := parent, ast.Call)
assert isinstance(operator_call.func, ast.Name)
scope = self.context["scope"].resolve(operator_call.func)
try:
import_node = next(
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.ImportFrom)
and any(alias.name == operator_call.func.id for alias in node.names)
and scope.can_reach(self.context["scope"].resolve(node))
)
and import_node.module
)
except StopIteration:
raise AssertionError("Could not find import definition. Skipping..")
assert (module_name := import_node.module)

# Try to import the module into python.
try:
_module = importlib.import_module(import_node.module)
_module = import_module(module_name)
except ImportError:
return
with open(_module.__file__) as file:
raise AssertionError("Could not import module. Skipping..")
assert (file_path := _module.__file__)

# Parse the ast to check if the keyword is in template_fields.
with open(file_path) as file:
module = ast.parse(file.read())
assert any(
isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == "template_fields"
and isinstance(stmt.value, ast.Tuple)
and any(
isinstance(elt, ast.Constant) and elt.value == operator_keyword.arg
isinstance(elt, ast.Constant) and elt.value == keyword.arg
for elt in stmt.value.elts
)
for module_stmt in module.body
if isinstance(module_stmt, ast.ClassDef)
for stmt in module_stmt.body
)

def _get_parameter(
self,
node: ast.Call,
position: int,
name: str,
) -> Any:
if position < len(node.args) and isinstance(
arg := node.args[position],
ast.Constant,
):
return arg.value
return next(
keyword.value.value
for keyword in node.keywords
if keyword.arg == name and isinstance(keyword.value, ast.Constant)
)

def _construct_value(self, node: ast.Call) -> str:
# Read key from Variable.get node.
key = self._get_parameter(node, position=0, name="key")

# Read optional deserialize_json from Variable.get node.
try:
deserialize_json = self._get_parameter(
node,
position=2,
name="deserialize_json",
)
var_type = "json" if deserialize_json else "value"
except StopIteration:
var_type = "value"

# Read optional default_var from Variable.get node and construct the final value.
try:
default_var = self._get_parameter(node, position=1, name="default_var")
if isinstance(default_var, str):
value = f"{{{{ var.{var_type}.get('{key}', '{default_var}') }}}}"
else:
value = f"{{{{ var.{var_type}.get('{key}', {default_var}) }}}}"
except StopIteration:
value = f"{{{{ var.{var_type}.{key} }}}}"

return value

def match(self, node):
assert (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "Variable"
and node.func.attr == "get"
and isinstance(node.func.ctx, ast.Load)
)

for operator_keyword in self._get_operator_keywords(reference=node):
self._lookup_template_fields(keyword=operator_keyword)

return ReplacementAction(
node,
target=ast.Constant(value=f"{{{{ var.value.{variable.value} }}}}"),
target=ast.Constant(value=self._construct_value(node)),
)
Loading