Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further improvements to existing rules #13

Merged
merged 4 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Respect scopes in UseJinjaVariableGet
- reorder test cases and add comments
  • Loading branch information
feluelle committed Jun 29, 2022
commit d2ec3425769d1f181c861fdaf5ccc6e208b9469a
35 changes: 23 additions & 12 deletions airflint/rules/use_jinja_variable_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,59 @@
from typing import Any

from refactor import ReplacementAction, Rule
from refactor.context import Ancestry
from refactor.context import Ancestry, Scope


class UseJinjaVariableGet(Rule):
"""Replace `Variable.get("foo")` Calls through the jinja equivalent `{{ var.value.foo }}` if the variable is listed in `template_fields`."""

context_providers = (Ancestry,)
context_providers = (Scope, Ancestry)

def _get_operator_keywords(self, reference: ast.Call) -> list[ast.keyword]:
parent = self.context["ancestry"].get_parent(reference)

if isinstance(parent, ast.Assign):
# Get all operator keywords referencing the variable, Variable.get call was assigned to.
return [
scope = self.context["scope"]
operator_keywords = [
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.keyword)
and isinstance(node.value, ast.Name)
and any(
node.value.id == target.id
and scope.resolve(node.value).can_reach(scope.resolve(target))
for target in parent.targets
if isinstance(target, ast.Name)
)
]
elif isinstance(parent, ast.keyword):
if operator_keywords:
return operator_keywords
raise AssertionError("No operator keywords found. Skipping..")

if isinstance(parent, ast.keyword):
# Direct reference without variable assignment.
return [parent]
else:
raise AssertionError("Not implemented. Skipping..")

raise AssertionError("Not implemented. Skipping..")

def _lookup_template_fields(self, keyword: ast.keyword) -> None:
parent = self.context["ancestry"].get_parent(keyword)

# Find the import node module matching the operator calls name.
assert isinstance(operator_call := parent, ast.Call)
assert isinstance(operator_call.func, ast.Name)
import_node = next(
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.ImportFrom)
and any(alias.name == operator_call.func.id for alias in node.names)
)
scope = self.context["scope"].resolve(operator_call.func)
try:
import_node = next(
node
for node in ast.walk(self.context.tree)
if isinstance(node, ast.ImportFrom)
and any(alias.name == operator_call.func.id for alias in node.names)
and scope.can_reach(self.context["scope"].resolve(node))
)
except StopIteration:
raise AssertionError("Could not find import definition. Skipping..")
assert (module_name := import_node.module)

# Try to import the module into python.
Expand Down
63 changes: 55 additions & 8 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
[
(
UseFunctionLevelImports,
# Test that all required imports within functions are being added to functions.
"""
from functools import reduce
from operator import add
Expand Down Expand Up @@ -44,6 +45,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that direct assignment of Variable.get is being transformed to jinja equivalent.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -58,8 +60,8 @@ def other_thing():
""",
),
(
# Test that nothing happens if it cannot import the module.
UseJinjaVariableGet,
# Test that nothing happens if it cannot import the module.
"""
from airflow.models import Variable
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
Expand All @@ -75,6 +77,43 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that nothing happens if the import cannot be reached.
"""
from airflow.models import Variable

def foo():
from airflow.operators.bash import BashOperator

BashOperator(task_id="foo", bash_command=Variable.get("FOO"))
""",
"""
from airflow.models import Variable

def foo():
from airflow.operators.bash import BashOperator

BashOperator(task_id="foo", bash_command=Variable.get("FOO"))
""",
),
(
UseJinjaVariableGet,
# Test that nothing happens if it is not in template_fields.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator

BashOperator(task_id="foo", output_encoding=Variable.get("FOO"))
""",
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator

BashOperator(task_id="foo", output_encoding=Variable.get("FOO"))
""",
),
(
UseJinjaVariableGet,
# Test that variable assignment of Variable.get is being transformed to jinja equivalent.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -93,27 +132,30 @@ def other_thing():
""",
),
(
# Test that nothing happens if it is not in template_fields.
UseJinjaVariableGet,
# Test that nothing happens if the variable cannot be reached.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator

var = Variable.get("FOO")
def foo():
var = Variable.get("FOO")

BashOperator(task_id="foo", output_encoding=var)
BashOperator(task_id="foo", bash_command=var)
""",
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator

var = Variable.get("FOO")
def foo():
var = Variable.get("FOO")

BashOperator(task_id="foo", output_encoding=var)
BashOperator(task_id="foo", bash_command=var)
""",
),
(
UseJinjaVariableGet,
# Test that variable assignment works for multiple keywords.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -132,8 +174,8 @@ def other_thing():
""",
),
(
# Test that nothing happens if at least one keyword is not in template_fields.
UseJinjaVariableGet,
# Test that nothing happens if at least one keyword is not in template_fields.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -152,8 +194,8 @@ def other_thing():
""",
),
(
# Test that nothing happens if variable is being referenced in multiple calls where at least one keyword is not in template_fields.
UseJinjaVariableGet,
# Test that nothing happens if variable is being referenced in multiple Calls where at least one keyword is not in template_fields.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -175,6 +217,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that variable assignment works for multiple Calls.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -196,6 +239,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls with deserialize_json works.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -211,6 +255,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls with default_var works.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -226,6 +271,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls with default_var=None works.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand All @@ -241,6 +287,7 @@ def other_thing():
),
(
UseJinjaVariableGet,
# Test that Variable.get calls works with both - deserialize_json and default_var.
"""
from airflow.models import Variable
from airflow.operators.bash import BashOperator
Expand Down