Skip to content

Commit

Permalink
Respect scopes in UseJinjaVariableGet
Browse files Browse the repository at this point in the history
- reorder test cases and add comments
  • Loading branch information
feluelle committed Jun 29, 2022
1 parent c0709b1 commit d2ec342
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 20 deletions.
35 changes: 23 additions & 12 deletions airflint/rules/use_jinja_variable_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,59 @@
from typing import Any

from refactor import ReplacementAction, Rule
from refactor.context import Ancestry
from refactor.context import Ancestry, Scope


class UseJinjaVariableGet(Rule):
"""Replace `Variable.get("foo")` Calls through the jinja equivalent `{{ var.value.foo }}` if the variable is listed in `template_fields`."""

context_providers = (Ancestry,)
context_providers = (Scope, Ancestry)

def _get_operator_keywords(self, reference: ast.Call) -> list[ast.keyword]:
parent = self.context["ancestry"].get_parent(reference)

if isinstance(parent, ast.Assign):
# Get all operator keywords referencing the variable, Variable.get call was assigned to.
return [
scope = self.context["scope"]
operator_keywords = [
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.keyword)
and isinstance(node.value, ast.Name)
and any(
node.value.id == target.id
and scope.resolve(node.value).can_reach(scope.resolve(target))
for target in parent.targets
if isinstance(target, ast.Name)
)
]
elif isinstance(parent, ast.keyword):
if operator_keywords:
return operator_keywords
raise AssertionError("No operator keywords found. Skipping..")

if isinstance(parent, ast.keyword):
# Direct reference without variable assignment.
return [parent]
else:
raise AssertionError("Not implemented. Skipping..")

raise AssertionError("Not implemented. Skipping..")

def _lookup_template_fields(self, keyword: ast.keyword) -> None:
parent = self.context["ancestry"].get_parent(keyword)

# Find the import node module matching the operator calls name.
assert isinstance(operator_call := parent, ast.Call)
assert isinstance(operator_call.func, ast.Name)
import_node = next(
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.ImportFrom)
and any(alias.name == operator_call.func.id for alias in node.names)
)
scope = self.context["scope"].resolve(operator_call.func)
try:
import_node = next(
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.ImportFrom)
and any(alias.name == operator_call.func.id for alias in node.names)
and scope.can_reach(self.context["scope"].resolve(node))
)
except StopIteration:
raise AssertionError("Could not find import definition. Skipping..")
assert (module_name := import_node.module)

# Try to import the module into python.
Expand Down
63 changes: 55 additions & 8 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
[
(
UseFunctionLevelImports,
# Test that all required imports within functions are being added to functions.
"""
from functools import reduce
from operator import add
Expand Down Expand Up @@ -44,6 +45,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that direct assignment of Variable.get is being transformed to jinja equivalent.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -58,8 +60,8 @@ def other_thing():
""",
),
(
# Test that nothing happens if it cannot import the module.
UseJinjaVariableGet,
# Test that nothing happens if it cannot import the module.
"""
from airflow.models import Variable
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
Expand All @@ -75,6 +77,43 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that nothing happens if the import cannot be reached.
"""
from airflow.models import Variable
def foo():
from airflow.operators.bash import BashOperator
BashOperator(task_id="foo", bash_command=Variable.get("FOO"))
""",
"""
from airflow.models import Variable
def foo():
from airflow.operators.bash import BashOperator
BashOperator(task_id="foo", bash_command=Variable.get("FOO"))
""",
),
(
UseJinjaVariableGet,
# Test that nothing happens if it is not in template_fields.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
BashOperator(task_id="foo", output_encoding=Variable.get("FOO"))
""",
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
BashOperator(task_id="foo", output_encoding=Variable.get("FOO"))
""",
),
(
UseJinjaVariableGet,
# Test that variable assignment of Variable.get is being transformed to jinja equivalent.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -93,27 +132,30 @@ def other_thing():
""",
),
(
# Test that nothing happens if it is not in template_fields.
UseJinjaVariableGet,
# Test that nothing happens if the variable cannot be reached.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
var = Variable.get("FOO")
def foo():
var = Variable.get("FOO")
BashOperator(task_id="foo", output_encoding=var)
BashOperator(task_id="foo", bash_command=var)
""",
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
var = Variable.get("FOO")
def foo():
var = Variable.get("FOO")
BashOperator(task_id="foo", output_encoding=var)
BashOperator(task_id="foo", bash_command=var)
""",
),
(
UseJinjaVariableGet,
# Test that variable assignment works for multiple keywords.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -132,8 +174,8 @@ def other_thing():
""",
),
(
# Test that nothing happens if at least one keyword is not in template_fields.
UseJinjaVariableGet,
# Test that nothing happens if at least one keyword is not in template_fields.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -152,8 +194,8 @@ def other_thing():
""",
),
(
# Test that nothing happens if variable is being referenced in multiple calls where at least one keyword is not in template_fields.
UseJinjaVariableGet,
# Test that nothing happens if variable is being referenced in multiple Calls where at least one keyword is not in template_fields.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -175,6 +217,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that variable assignment works for multiple Calls.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -196,6 +239,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls with deserialize_json works.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -211,6 +255,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls with default_var works.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -226,6 +271,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls with default_var=None works.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -241,6 +287,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls works with both - deserialize_json and default_var.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand Down

0 comments on commit d2ec342

Please sign in to comment.