Skip to content

Commit 00082ff

Browse files
[v3-0-test] Restore operator extra links for mapped tasks (#50238) (#50244)
closes #49773 (cherry picked from commit 88148ff) Co-authored-by: Amogh Desai <[email protected]>
1 parent 8bacdd0 commit 00082ff

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

airflow-core/src/airflow/models/mappedoperator.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,22 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20-
from typing import TYPE_CHECKING
20+
from functools import cached_property
21+
from typing import TYPE_CHECKING, Any
2122

2223
import attrs
2324
import structlog
2425

26+
from airflow.exceptions import AirflowException
2527
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
2628
from airflow.triggers.base import StartTriggerArgs
2729
from airflow.utils.helpers import prevent_duplicates
2830

2931
if TYPE_CHECKING:
3032
from sqlalchemy.orm.session import Session
3133

34+
from airflow.models import TaskInstance
35+
from airflow.sdk import BaseOperatorLink
3236
from airflow.sdk.definitions.context import Context
3337

3438
log = structlog.get_logger(__name__)
@@ -118,3 +122,54 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St
118122
next_kwargs=next_kwargs,
119123
timeout=timeout,
120124
)
125+
126+
@cached_property
127+
def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]:
128+
"""Returns dictionary of all extra links for the operator."""
129+
op_extra_links_from_plugin: dict[str, Any] = {}
130+
from airflow import plugins_manager
131+
132+
plugins_manager.initialize_extra_operators_links_plugins()
133+
if plugins_manager.operator_extra_links is None:
134+
raise AirflowException("Can't load operators")
135+
operator_class_type = self.operator_class["task_type"] # type: ignore
136+
for ope in plugins_manager.operator_extra_links:
137+
if ope.operators and any(operator_class_type in cls.__name__ for cls in ope.operators):
138+
op_extra_links_from_plugin.update({ope.name: ope})
139+
140+
operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
141+
# Extra links defined in Plugins overrides operator links defined in operator
142+
operator_extra_links_all.update(op_extra_links_from_plugin)
143+
144+
return operator_extra_links_all
145+
146+
@cached_property
147+
def global_operator_extra_link_dict(self) -> dict[str, Any]:
148+
"""Returns dictionary of all global extra links."""
149+
from airflow import plugins_manager
150+
151+
plugins_manager.initialize_extra_operators_links_plugins()
152+
if plugins_manager.global_operator_extra_links is None:
153+
raise AirflowException("Can't load operators")
154+
return {link.name: link for link in plugins_manager.global_operator_extra_links}
155+
156+
@cached_property
157+
def extra_links(self) -> list[str]:
158+
return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))
159+
160+
def get_extra_links(self, ti: TaskInstance, name: str) -> str | None:
161+
"""
162+
For an operator, gets the URLs that the ``extra_links`` entry points to.
163+
164+
:meta private:
165+
166+
:raise ValueError: The error message of a ValueError will be passed on through to
167+
the fronted to show up as a tooltip on the disabled link.
168+
:param ti: The TaskInstance for the URL being searched for.
169+
:param name: The name of the link we're looking for the URL for. Should be
170+
one of the options specified in ``extra_links``.
171+
"""
172+
link = self.operator_extra_link_dict.get(name) or self.global_operator_extra_link_dict.get(name)
173+
if not link:
174+
return None
175+
return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type]

airflow-core/tests/unit/serialization/test_dag_serialization.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@
9191
from tests_common.test_utils.config import conf_vars
9292
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
9393
from tests_common.test_utils.mock_operators import (
94+
AirflowLink,
9495
AirflowLink2,
9596
CustomOperator,
97+
GithubLink,
9698
GoogleLink,
9799
MockOperator,
98100
)
@@ -3095,6 +3097,13 @@ def operator_extra_links(self):
30953097
XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2")
30963098
]
30973099

3100+
mapped_task = deserialized_dag.task_dict["task"]
3101+
assert mapped_task.operator_extra_link_dict == {
3102+
"airflow": XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2")
3103+
}
3104+
assert mapped_task.global_operator_extra_link_dict == {"airflow": AirflowLink(), "github": GithubLink()}
3105+
assert mapped_task.extra_links == sorted({"airflow", "github"})
3106+
30983107

30993108
def test_handle_v1_serdag():
31003109
v1 = {

0 commit comments

Comments
 (0)