Skip to content

Commit

Permalink
Add kill-switch to plugin v2
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es committed Dec 11, 2024
1 parent 267e96a commit 1511833
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from airflow.models import Variable

from datahub_airflow_plugin.datahub_listener import (
get_airflow_plugin_listener,
hookimpl,
Expand All @@ -17,29 +15,21 @@
def on_task_instance_running(previous_state, task_instance, session):
assert _listener
# This is a bit hacky way to provide a way to disable the listener
if Variable.get("datahub_airflow_plugin_disable_listener", "false").lower() == "true":
return
_listener.on_task_instance_running(previous_state, task_instance, session)

@hookimpl
def on_task_instance_success(previous_state, task_instance, session):
assert _listener
if Variable.get("datahub_airflow_plugin_disable_listener", "false").lower() == "true":
return
_listener.on_task_instance_success(previous_state, task_instance, session)

@hookimpl
def on_task_instance_failed(previous_state, task_instance, session):
assert _listener
if Variable.get("datahub_airflow_plugin_disable_listener", "false").lower() == "true":
return
_listener.on_task_instance_failed(previous_state, task_instance, session)

if hasattr(_listener, "on_dag_run_running"):

@hookimpl
def on_dag_run_running(dag_run, msg):
assert _listener
if Variable.get("datahub_airflow_plugin_disable_listener", "false").lower() == "true":
return
_listener.on_dag_run_running(dag_run, msg)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import airflow
import datahub.emitter.mce_builder as builder
from airflow.models import Variable
from airflow.models.serialized_dag import SerializedDagModel
from datahub.api.entities.datajob import DataJob
from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult
Expand Down Expand Up @@ -50,6 +51,8 @@
entities_to_dataset_urn_list,
)

KILL_SWITCH_VARIABLE_NAME = "datahub_airflow_plugin_disable_listener"

_F = TypeVar("_F", bound=Callable[..., None])
if TYPE_CHECKING:
from airflow.datasets import Dataset
Expand Down Expand Up @@ -371,6 +374,9 @@ def on_task_instance_running(
task_instance: "TaskInstance",
session: "Session", # This will always be QUEUED
) -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return

self._set_log_level()

# This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508.
Expand Down Expand Up @@ -481,6 +487,8 @@ def on_task_instance_running(
def on_task_instance_finish(
self, task_instance: "TaskInstance", status: InstanceRunResult
) -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return
dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined]

if self.config.render_templates:
Expand Down Expand Up @@ -540,6 +548,9 @@ def on_task_instance_finish(
def on_task_instance_success(
self, previous_state: None, task_instance: "TaskInstance", session: "Session"
) -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return

self._set_log_level()

logger.debug(
Expand All @@ -555,6 +566,9 @@ def on_task_instance_success(
def on_task_instance_failed(
self, previous_state: None, task_instance: "TaskInstance", session: "Session"
) -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return

self._set_log_level()

logger.debug(
Expand All @@ -568,6 +582,9 @@ def on_task_instance_failed(
)

def on_dag_start(self, dag_run: "DagRun") -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return

dag = dag_run.dag
if not dag:
logger.warning(
Expand Down Expand Up @@ -695,6 +712,9 @@ def on_dag_start(self, dag_run: "DagRun") -> None:
@hookimpl
@run_in_thread
def on_dag_run_running(self, dag_run: "DagRun", msg: str) -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return

self._set_log_level()

logger.debug(
Expand All @@ -716,6 +736,9 @@ def on_dag_run_running(self, dag_run: "DagRun", msg: str) -> None:
@hookimpl
@run_in_thread
def on_dataset_created(self, dataset: "Dataset") -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return

self._set_log_level()

logger.debug(
Expand All @@ -725,6 +748,9 @@ def on_dataset_created(self, dataset: "Dataset") -> None:
@hookimpl
@run_in_thread
def on_dataset_changed(self, dataset: "Dataset") -> None:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
return

self._set_log_level()

logger.debug(
Expand Down

0 comments on commit 1511833

Please sign in to comment.