Skip to content

Commit

Permalink
Fix TaskFlowApi for tasks
Browse files Browse the repository at this point in the history
- respect assignments
- respect op_args and op_kwargs
  • Loading branch information
feluelle committed Apr 19, 2022
1 parent 40d9fa2 commit ad13966
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 5 deletions.
28 changes: 23 additions & 5 deletions airflint/rules/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def match(self, node):
keywords=[
keyword
for keyword in python_operator.keywords
if keyword.arg != "python_callable"
if keyword.arg not in ["python_callable", "op_args", "op_kwargs"]
],
),
)
Expand All @@ -81,20 +81,38 @@ class _ReplacePythonOperatorByFunctionCall(Rule):
"""Replace PythonOperator calls by function calls which got decorated with the @task decorator."""

def match(self, node):
assert isinstance(node, ast.Expr)
assert isinstance(node, (ast.Expr, ast.Assign))
assert isinstance(node.value, ast.Call)
assert isinstance(node.value.func, ast.Name)
assert node.value.func.id in ["PythonOperator", "PythonVirtualenvOperator"]
assert isinstance(node.value.func.ctx, ast.Load)

replacement = ast.Call(
replacement = deepcopy(node)

args = next(
(
keyword.value.elts
for keyword in node.value.keywords
if keyword.arg == "op_args"
),
None,
)
kwargs = next(
(
keyword.value.keywords
for keyword in node.value.keywords
if keyword.arg == "op_kwargs"
),
None,
)
replacement.value = ast.Call(
func=next(
keyword.value
for keyword in node.value.keywords
if keyword.arg == "python_callable"
),
args=[],
keywords=[],
args=[args] if args else [],
keywords=[kwargs] if kwargs else [],
)
return ReplacementAction(node, replacement)

Expand Down
42 changes: 42 additions & 0 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,48 @@ def foo():
foo()
""",
),
(
task.EnforceTaskFlowApi,
"""
from airflow.operators.python import PythonOperator
def foo():
pass
task_foo = PythonOperator(task_id="foo", python_callable=foo)
""",
"""
from airflow.operators.python import PythonOperator
from airflow.decorators import task
@task(task_id="foo")
def foo():
pass
task_foo = foo()
""",
),
(
task.EnforceTaskFlowApi,
"""
from airflow.operators.python import PythonOperator
def foo(fizz, bar):
pass
task_foo = PythonOperator(task_id="foo", python_callable=foo, op_kwargs=dict(bar="bar"), op_args=["fizz"])
""",
"""
from airflow.operators.python import PythonOperator
from airflow.decorators import task
@task(task_id="foo")
def foo(fizz, bar):
pass
task_foo = foo("fizz", bar="bar")
""",
),
(
[variable.ReplaceVariableGetByJinja],
"""
Expand Down

0 comments on commit ad13966

Please sign in to comment.