Skip to content

Commit 0edc4f4

Browse files
authored
Update BaseOperator imports for Airflow 3.0 compatibility (#52503)
1 parent 98fd279 commit 0edc4f4

File tree

4 files changed

+49
-4
lines changed

4 files changed

+49
-4
lines changed

providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from typing import Any
2222

2323
from airflow.exceptions import AirflowException
24-
from airflow.models import BaseOperator
2524
from airflow.providers.apache.kafka.hooks.consume import KafkaConsumerHook
25+
from airflow.providers.apache.kafka.version_compat import BaseOperator
2626
from airflow.utils.module_loading import import_string
2727

2828
VALID_COMMIT_CADENCE = {"never", "end_of_batch", "end_of_operator"}

providers/apache/kafka/src/airflow/providers/apache/kafka/operators/produce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from typing import Any
2323

2424
from airflow.exceptions import AirflowException
25-
from airflow.models import BaseOperator
2625
from airflow.providers.apache.kafka.hooks.produce import KafkaProducerHook
26+
from airflow.providers.apache.kafka.version_compat import BaseOperator
2727
from airflow.utils.module_loading import import_string
2828

2929
local_logger = logging.getLogger("airflow")

providers/apache/kafka/src/airflow/providers/apache/kafka/sensors/kafka.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from collections.abc import Callable, Sequence
2020
from typing import Any
2121

22-
from airflow.models import BaseOperator
2322
from airflow.providers.apache.kafka.triggers.await_message import AwaitMessageTrigger
23+
from airflow.providers.apache.kafka.version_compat import BaseOperator
2424

2525
VALID_COMMIT_CADENCE = {"never", "end_of_batch", "end_of_operator"}
2626

@@ -107,7 +107,7 @@ def execute(self, context) -> Any:
107107

108108
def execute_complete(self, context, event=None):
109109
if self.xcom_push_key:
110-
self.xcom_push(context, key=self.xcom_push_key, value=event)
110+
context["task_instance"].xcom_push(key=self.xcom_push_key, value=event)
111111
return event
112112

113113

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# NOTE! THIS FILE IS COPIED MANUALLY IN OTHER PROVIDERS DELIBERATELY TO AVOID ADDING UNNECESSARY
19+
# DEPENDENCIES BETWEEN PROVIDERS. IF YOU WANT TO ADD CONDITIONAL CODE IN YOUR PROVIDER THAT DEPENDS
20+
# ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR PROVIDER AND IMPORT
21+
# THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR TEST CODE
22+
#
23+
from __future__ import annotations
24+
25+
26+
def get_base_airflow_version_tuple() -> tuple[int, int, int]:
27+
from packaging.version import Version
28+
29+
from airflow import __version__
30+
31+
airflow_version = Version(__version__)
32+
return airflow_version.major, airflow_version.minor, airflow_version.micro
33+
34+
35+
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
36+
37+
if AIRFLOW_V_3_0_PLUS:
38+
from airflow.sdk import BaseOperator
39+
else:
40+
from airflow.models import BaseOperator
41+
42+
__all__ = [
43+
"AIRFLOW_V_3_0_PLUS",
44+
"BaseOperator",
45+
]

0 commit comments

Comments
 (0)