Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion airflow-core/src/airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
from functools import cached_property
from typing import TYPE_CHECKING, Any

import attrs
import structlog

from airflow.exceptions import AirflowException
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.helpers import prevent_duplicates

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.models import TaskInstance
from airflow.sdk import BaseOperatorLink
from airflow.sdk.definitions.context import Context

log = structlog.get_logger(__name__)
Expand Down Expand Up @@ -118,3 +122,54 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St
next_kwargs=next_kwargs,
timeout=timeout,
)

@cached_property
def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]:
"""Returns dictionary of all extra links for the operator."""
op_extra_links_from_plugin: dict[str, Any] = {}
from airflow import plugins_manager

plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.operator_extra_links is None:
raise AirflowException("Can't load operators")
operator_class_type = self.operator_class["task_type"] # type: ignore
for ope in plugins_manager.operator_extra_links:
if ope.operators and any(operator_class_type in cls.__name__ for cls in ope.operators):
op_extra_links_from_plugin.update({ope.name: ope})

operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
# Extra links defined in Plugins overrides operator links defined in operator
operator_extra_links_all.update(op_extra_links_from_plugin)

return operator_extra_links_all

@cached_property
def global_operator_extra_link_dict(self) -> dict[str, Any]:
"""Returns dictionary of all global extra links."""
from airflow import plugins_manager

plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.global_operator_extra_links is None:
raise AirflowException("Can't load operators")
return {link.name: link for link in plugins_manager.global_operator_extra_links}

@cached_property
def extra_links(self) -> list[str]:
return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

def get_extra_links(self, ti: TaskInstance, name: str) -> str | None:
"""
For an operator, gets the URLs that the ``extra_links`` entry points to.

:meta private:

:raise ValueError: The error message of a ValueError will be passed on through to
the fronted to show up as a tooltip on the disabled link.
:param ti: The TaskInstance for the URL being searched for.
:param name: The name of the link we're looking for the URL for. Should be
one of the options specified in ``extra_links``.
"""
link = self.operator_extra_link_dict.get(name) or self.global_operator_extra_link_dict.get(name)
if not link:
return None
return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type]
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
from tests_common.test_utils.mock_operators import (
AirflowLink,
AirflowLink2,
CustomOperator,
GithubLink,
GoogleLink,
MockOperator,
)
Expand Down Expand Up @@ -3095,6 +3097,13 @@ def operator_extra_links(self):
XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2")
]

mapped_task = deserialized_dag.task_dict["task"]
assert mapped_task.operator_extra_link_dict == {
"airflow": XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2")
}
assert mapped_task.global_operator_extra_link_dict == {"airflow": AirflowLink(), "github": GithubLink()}
assert mapped_task.extra_links == sorted({"airflow", "github"})


def test_handle_v1_serdag():
v1 = {
Expand Down