Commit c453684

Fix deferrable mode for SparkKubernetesOperator
1 parent bd4bfa7 commit c453684
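
Deferrable mode lets the operator submit the SparkApplication and then hand monitoring of the driver pod off to the triggerer instead of occupying a worker slot while the job runs. A minimal usage sketch, not part of this commit (the DAG id, schedule, and application_file path below are illustrative):

import pendulum
from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator

with DAG(
    dag_id="spark_pi_deferrable",  # illustrative name
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
):
    # With deferrable=True the task defers to a trigger after the
    # SparkApplication CRD is created, which is the code path this commit fixes.
    submit = SparkKubernetesOperator(
        task_id="submit_spark_pi",
        application_file="spark_pi.yaml",  # illustrative template file
        kubernetes_conn_id="kubernetes_default",
        get_logs=True,
        deferrable=True,
    )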

File tree

3 files changed: +60 additions, -16 deletions

providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py

Lines changed: 7 additions & 5 deletions
@@ -796,11 +796,13 @@ def _refresh_cached_properties(self):
         del self.pod_manager

     def execute_async(self, context: Context) -> None:
-        self.pod_request_obj = self.build_pod_request_obj(context)
-        self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
-            pod_request_obj=self.pod_request_obj,
-            context=context,
-        )
+        if self.pod_request_obj is None:
+            self.pod_request_obj = self.build_pod_request_obj(context)
+        if self.pod is None:
+            self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
+                pod_request_obj=self.pod_request_obj,
+                context=context,
+            )
         if self.callbacks:
             pod = self.find_pod(self.pod.metadata.namespace, context=context)
             for callback in self.callbacks:
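
SparkKubernetesOperator.execute() builds the Spark driver pod and sets self.pod / self.pod_request_obj before calling super().execute(); in deferrable mode that call reaches execute_async(), which previously rebuilt a generic pod request and discarded the driver pod. The new None checks make execute_async() reuse whatever a subclass already prepared. A standalone sketch of the same idempotent-initialization pattern (the class and method names here are illustrative, not Airflow APIs):

from __future__ import annotations


class BaseLauncher:
    def __init__(self) -> None:
        self.pod: dict | None = None

    def _build_pod(self) -> dict:
        return {"kind": "Pod", "built_by": "base"}

    def execute_async(self) -> dict:
        # Only build a pod if a subclass has not already supplied one.
        if self.pod is None:
            self.pod = self._build_pod()
        return self.pod


class SparkLikeLauncher(BaseLauncher):
    def execute_async(self) -> dict:
        self.pod = {"kind": "Pod", "built_by": "spark subclass"}  # driver pod analogue
        return super().execute_async()  # must not overwrite self.pod


assert SparkLikeLauncher().execute_async()["built_by"] == "spark subclass"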

providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py

Lines changed: 15 additions & 10 deletions
@@ -254,22 +254,23 @@ def find_spark_job(self, context, exclude_checked: bool = True):
             self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"])
         return pod

-    def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) -> k8s.V1Pod:
+    def get_or_create_spark_crd(self, context) -> k8s.V1Pod:
         if self.reattach_on_restart:
             driver_pod = self.find_spark_job(context)
             if driver_pod:
                 return driver_pod

-        driver_pod, spark_obj_spec = launcher.start_spark_job(
+        driver_pod, spark_obj_spec = self.launcher.start_spark_job(
             image=self.image, code_path=self.code_path, startup_timeout=self.startup_timeout_seconds
         )
         return driver_pod

     def process_pod_deletion(self, pod, *, reraise=True):
         if pod is not None:
             if self.delete_on_termination:
-                self.log.info("Deleting spark job: %s", pod.metadata.name.replace("-driver", ""))
-                self.launcher.delete_spark_job(pod.metadata.name.replace("-driver", ""))
+                pod_name = pod.metadata.name.replace("-driver", "")
+                self.log.info("Deleting spark job: %s", pod_name)
+                self.launcher.delete_spark_job(pod_name)
             else:
                 self.log.info("skipping deleting spark job: %s", pod.metadata.name)

@@ -293,18 +294,22 @@ def client(self) -> CoreV1Api:
     def custom_obj_api(self) -> CustomObjectsApi:
         return CustomObjectsApi()

-    def execute(self, context: Context):
-        self.name = self.create_job_name()
-
-        self.log.info("Creating sparkApplication.")
-        self.launcher = CustomObjectLauncher(
+    @cached_property
+    def launcher(self) -> CustomObjectLauncher:
+        launcher = CustomObjectLauncher(
             name=self.name,
             namespace=self.namespace,
             kube_client=self.client,
             custom_obj_api=self.custom_obj_api,
             template_body=self.template_body,
         )
-        self.pod = self.get_or_create_spark_crd(self.launcher, context)
+        return launcher
+
+    def execute(self, context: Context):
+        self.name = self.create_job_name()
+
+        self.log.info("Creating sparkApplication.")
+        self.pod = self.get_or_create_spark_crd(context)
         self.pod_request_obj = self.launcher.pod_spec

         return super().execute(context=context)
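
Turning launcher into a cached_property means code paths such as execute(), process_pod_deletion(), and cleanup after a deferral can reach self.launcher without depending on execute() having constructed it first, while the CustomObjectLauncher is still built only once per operator instance. A generic sketch of that pattern with stand-in classes (all names here are illustrative):

from functools import cached_property


class Launcher:
    def __init__(self, name: str) -> None:
        self.name = name
        print(f"launcher built for {name}")  # runs only once per Operator instance


class Operator:
    def __init__(self, name: str) -> None:
        self.name = name

    @cached_property
    def launcher(self) -> Launcher:
        # Built lazily on first access, then memoized on the instance,
        # so execute() and cleanup code share the same object.
        return Launcher(self.name)


op = Operator("spark-pi")
assert op.launcher is op.launcher  # constructed exactly once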

providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py

Lines changed: 38 additions & 1 deletion
@@ -23,7 +23,7 @@
 from datetime import date
 from functools import cached_property
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import mock_open, patch
 from uuid import uuid4

 import pendulum
@@ -32,9 +32,11 @@
 from kubernetes.client import models as k8s

 from airflow import DAG
+from airflow.exceptions import TaskDeferred
 from airflow.models import Connection, DagRun, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
 from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN
+from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType

@@ -754,6 +756,41 @@ def test_find_custom_pod_labels(
         op.find_spark_job(context)
         mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)

+    @pytest.mark.asyncio
+    def test_execute_deferrable(
+        self,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+        mocker,
+    ):
+        task_name = "test_execute_deferrable"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            deferrable=True,
+        )
+        context = create_context(op)
+
+        mock_file = mock_open(read_data='{"a": "b"}')
+        mocker.patch("builtins.open", mock_file)
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(context)
+
+        assert isinstance(exc.value.trigger, KubernetesPodTrigger)
+

 @pytest.mark.db_test
 def test_template_body_templating(create_task_instance_of_operator, session):
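
For context on what the new test asserts: a deferrable operator suspends itself by raising TaskDeferred (normally via BaseOperator.defer()), and that exception carries the trigger the triggerer should run, which is why the test can catch it with pytest.raises and inspect exc.value.trigger. A minimal sketch of the mechanism, using TimeDeltaTrigger as a stand-in so it runs without a Kubernetes cluster (the function names are illustrative):

from datetime import timedelta

import pytest

from airflow.exceptions import TaskDeferred
from airflow.triggers.temporal import TimeDeltaTrigger


def defer_like_an_operator() -> None:
    # BaseOperator.defer() ultimately raises TaskDeferred with the trigger
    # to run and the method to resume in; execute() never returns normally.
    raise TaskDeferred(
        trigger=TimeDeltaTrigger(timedelta(seconds=1)),
        method_name="execute_complete",
    )


def test_defers() -> None:
    with pytest.raises(TaskDeferred) as exc:
        defer_like_an_operator()
    assert isinstance(exc.value.trigger, TimeDeltaTrigger)
    assert exc.value.method_name == "execute_complete"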
