Skip to content

Commit

Permalink
Fix TaskFlowApi and VariableGet replacements
Browse files Browse the repository at this point in the history
- for task flow api consider only tasks with constants, except for python_callable, op_args and op_kwargs
- do not raise an exception on failed import, skip instead
  • Loading branch information
feluelle committed Apr 21, 2022
1 parent af167cd commit a447e2a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
10 changes: 10 additions & 0 deletions airflint/rules/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def match(self, node):
scope=node_scope,
)
)
assert all(
isinstance(keyword.value, ast.Constant)
for keyword in python_operator.keywords
if keyword.arg not in ["python_callable", "op_args", "op_kwargs"]
)
assert isinstance(python_operator.func, ast.Name)
TASK_MAPPING = {
"PythonOperator": ast.Name(id="task", ctx=ast.Load()),
Expand Down Expand Up @@ -100,6 +105,11 @@ def match(self, node):
for keyword in node.value.keywords
)
)
assert all(
isinstance(keyword.value, ast.Constant)
for keyword in node.value.keywords
if keyword.arg not in ["python_callable", "op_args", "op_kwargs"]
)

args = next(
(
Expand Down
6 changes: 5 additions & 1 deletion airflint/rules/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ def match(self, node):
)
)
and import_node.module
and (file_path := importlib.import_module(import_node.module).__file__)
)
try:
_module = importlib.import_module(import_node.module)
except ImportError:
pass
assert _module and (file_path := _module.__file__)
with open(file_path) as file:
module = ast.parse(file.read())
assert any(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,28 @@ def foo():
PythonOperator(task_id="foo", python_callable=lambda: "foo")
""",
),
(
task.EnforceTaskFlowApi,
"""
from airflow.operators.python import PythonOperator
def foo():
pass
task_id = "foo"
PythonOperator(task_id=task_id, python_callable=foo)
""",
"""
from airflow.operators.python import PythonOperator
from airflow.decorators import task
def foo():
pass
task_id = "foo"
PythonOperator(task_id=task_id, python_callable=foo)
""",
),
(
task.EnforceTaskFlowApi,
"""
Expand Down

0 comments on commit a447e2a

Please sign in to comment.