Skip to content

Commit 72892d2

Browse files
authored
Fix type import to AbstractOperator (#51773)
1 parent fa574c0 commit 72892d2

File tree

4 files changed

+12
-15
lines changed

4 files changed

+12
-15
lines changed

providers/openlineage/src/airflow/providers/openlineage/utils/utils.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from airflow import __version__ as AIRFLOW_VERSION
3333

3434
# TODO: move this maybe to Airflow's logic?
35-
from airflow.models import BaseOperator, DagRun, TaskReschedule
35+
from airflow.models import DagRun, TaskReschedule
3636
from airflow.providers.openlineage import (
3737
__version__ as OPENLINEAGE_PROVIDER_VERSION,
3838
conf,
@@ -59,18 +59,14 @@
5959
if not AIRFLOW_V_3_0_PLUS:
6060
from airflow.utils.session import NEW_SESSION, provide_session
6161

62-
try:
63-
from airflow.sdk import BaseOperator as SdkBaseOperator
64-
except ImportError:
65-
SdkBaseOperator = BaseOperator # type: ignore[misc]
66-
6762
if TYPE_CHECKING:
6863
from openlineage.client.event_v2 import Dataset as OpenLineageDataset
6964
from openlineage.client.facet_v2 import RunFacet, processing_engine_run
7065

7166
from airflow.models import TaskInstance
7267
from airflow.providers.common.compat.assets import Asset
7368
from airflow.sdk import DAG
69+
from airflow.sdk.bases.operator import BaseOperator
7470
from airflow.sdk.definitions.mappedoperator import MappedOperator
7571
from airflow.sdk.execution_time.secrets_masker import (
7672
Redactable,
@@ -82,9 +78,10 @@
8278
else:
8379
try:
8480
from airflow.sdk import DAG
81+
from airflow.sdk.bases.operator import BaseOperator
8582
from airflow.sdk.definitions.mappedoperator import MappedOperator
8683
except ImportError:
87-
from airflow.models import DAG, MappedOperator
84+
from airflow.models import DAG, BaseOperator, MappedOperator
8885

8986
try:
9087
from airflow.providers.common.compat.assets import Asset
@@ -119,7 +116,7 @@ def try_import_from_string(string: str) -> Any:
119116
return import_string(string)
120117

121118

122-
def get_operator_class(task: BaseOperator | SdkBaseOperator) -> type:
119+
def get_operator_class(task: BaseOperator) -> type:
123120
if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"):
124121
return task.operator_class
125122
return task.__class__
@@ -203,25 +200,25 @@ def get_user_provided_run_facets(ti: TaskInstance, ti_state: TaskInstanceState)
203200
return custom_facets
204201

205202

206-
def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator | SdkBaseOperator) -> str:
203+
def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> str:
207204
if isinstance(operator, (MappedOperator, SerializedBaseOperator)):
208205
# as in airflow.api_connexion.schemas.common_schema.ClassReferenceSchema
209206
return operator._task_module + "." + operator._task_type # type: ignore
210207
op_class = get_operator_class(operator)
211208
return op_class.__module__ + "." + op_class.__name__
212209

213210

214-
def is_operator_disabled(operator: BaseOperator | MappedOperator | SdkBaseOperator) -> bool:
211+
def is_operator_disabled(operator: BaseOperator | MappedOperator) -> bool:
215212
return get_fully_qualified_class_name(operator) in conf.disabled_operators()
216213

217214

218-
def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator | SdkBaseOperator) -> bool:
215+
def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator) -> bool:
219216
"""If selective enable is active check if DAG or Task is enabled to emit events."""
220217
if not conf.selective_enable():
221218
return True
222219
if isinstance(obj, DAG):
223220
return is_dag_lineage_enabled(obj)
224-
if isinstance(obj, (BaseOperator, MappedOperator, SdkBaseOperator)):
221+
if isinstance(obj, (BaseOperator, MappedOperator)):
225222
return is_task_lineage_enabled(obj)
226223
raise TypeError("is_selective_lineage_enabled can only be used on DAG or Operator objects")
227224

task-sdk/src/airflow/sdk/definitions/_internal/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from airflow.sdk.definitions._internal.mixins import DependencyMixin
2828

2929
if TYPE_CHECKING:
30-
from airflow.sdk.definitions.abstractoperator import Operator
30+
from airflow.sdk.definitions._internal.abstractoperator import Operator
3131
from airflow.sdk.definitions.dag import DAG
3232
from airflow.sdk.definitions.edges import EdgeModifier
3333
from airflow.sdk.definitions.taskgroup import TaskGroup

task-sdk/src/airflow/sdk/definitions/dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474

7575
from pendulum.tz.timezone import FixedTimezone, Timezone
7676

77-
from airflow.sdk.definitions.abstractoperator import Operator
77+
from airflow.sdk.definitions._internal.abstractoperator import Operator
7878
from airflow.sdk.definitions.decorators import TaskDecoratorCollection
7979
from airflow.sdk.definitions.taskgroup import TaskGroup
8080
from airflow.typing_compat import Self

task-sdk/src/airflow/sdk/definitions/taskgroup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def iter_mapped_dependencies(self) -> Iterator[Operator]:
663663

664664
def task_group_to_dict(task_item_or_group):
665665
"""Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
666-
from airflow.sdk.definitions.abstractoperator import AbstractOperator
666+
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
667667
from airflow.sdk.definitions.mappedoperator import MappedOperator
668668
from airflow.sensors.base import BaseSensorOperator
669669

0 commit comments

Comments
 (0)