Skip to content

Commit 0c7ac56

Browse files
committed
Refactor datetime_range_filter_factory: coalesce only start_date and end_date filters
1 parent 581ea79 commit 0c7ac56

File tree

2 files changed

+36
-42
lines changed

2 files changed

+36
-42
lines changed

airflow-core/src/airflow/api_fastapi/common/parameters.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from fastapi import Depends, HTTPException, Query, status
3838
from pendulum.parsing.exceptions import ParserError
3939
from pydantic import AfterValidator, BaseModel, NonNegativeInt
40-
from sqlalchemy import Column, and_, case, or_
40+
from sqlalchemy import Column, and_, case, func, or_
4141
from sqlalchemy.inspection import inspect
4242

4343
from airflow.api_fastapi.core_api.base import OrmClause
@@ -493,9 +493,12 @@ def depends_datetime(
493493
lower_bound: datetime | None = Query(alias=f"{filter_name}_gte", default=None),
494494
upper_bound: datetime | None = Query(alias=f"{filter_name}_lte", default=None),
495495
) -> RangeFilter:
496+
attr = getattr(model, attribute_name or filter_name)
497+
if filter_name in ("start_date", "end_date"):
498+
attr = func.coalesce(attr, func.now())
496499
return RangeFilter(
497500
Range(lower_bound=lower_bound, upper_bound=upper_bound),
498-
getattr(model, attribute_name or filter_name),
501+
attr,
499502
)
500503

501504
return depends_datetime

airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717

1818
from __future__ import annotations
1919

20-
from collections.abc import Sequence
21-
from typing import Annotated, Any, Literal, cast
20+
from typing import Annotated, Literal, cast
2221

2322
import structlog
2423
from fastapi import Depends, HTTPException, Query, status
25-
from sqlalchemy import func, or_, select
24+
from sqlalchemy import or_, select
2625
from sqlalchemy.orm import joinedload
2726
from sqlalchemy.sql.selectable import Select
2827

@@ -50,7 +49,6 @@
5049
float_range_filter_factory,
5150
)
5251
from airflow.api_fastapi.common.router import AirflowRouter
53-
from airflow.api_fastapi.core_api.base import OrmClause
5452
from airflow.api_fastapi.core_api.datamodels.common import BulkBody, BulkResponse
5553
from airflow.api_fastapi.core_api.datamodels.task_instances import (
5654
BulkTaskInstanceBody,
@@ -196,26 +194,21 @@ def get_mapped_task_instances(
196194
error_message = f"Task id {task_id} is not mapped"
197195
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
198196

199-
filters = [
200-
run_after_range,
201-
logical_date_range,
202-
start_date_range,
203-
end_date_range,
204-
update_at_range,
205-
duration_range,
206-
state,
207-
pool,
208-
queue,
209-
executor,
210-
version_number,
211-
]
212-
if state.value is not None and None in state.value:
213-
current_time = func.now()
214-
for f in (start_date_range, end_date_range):
215-
f.attribute = func.coalesce(f.attribute, current_time)
216197
task_instance_select, total_entries = paginated_select(
217198
statement=query,
218-
filters=cast("Sequence[OrmClause[Any]]", filters),
199+
filters=[
200+
run_after_range,
201+
logical_date_range,
202+
start_date_range,
203+
end_date_range,
204+
update_at_range,
205+
duration_range,
206+
state,
207+
pool,
208+
queue,
209+
executor,
210+
version_number,
211+
],
219212
order_by=order_by,
220213
offset=offset,
221214
limit=limit,
@@ -467,26 +460,24 @@ def get_task_instances(
467460
)
468461
query = query.where(TI.run_id == dag_run_id)
469462

470-
filters = [
471-
run_after_range,
472-
logical_date_range,
473-
start_date_range,
474-
end_date_range,
475-
update_at_range,
476-
duration_range,
477-
state,
478-
pool,
479-
queue,
480-
executor,
481-
version_number,
482-
]
483-
if state.value is not None and None in state.value:
484-
current_time = func.now()
485-
for f in (start_date_range, end_date_range):
486-
f.attribute = func.coalesce(f.attribute, current_time)
487463
task_instance_select, total_entries = paginated_select(
488464
statement=query,
489-
filters=cast("Sequence[OrmClause[Any]]", filters),
465+
filters=[
466+
run_after_range,
467+
logical_date_range,
468+
start_date_range,
469+
end_date_range,
470+
update_at_range,
471+
duration_range,
472+
state,
473+
pool,
474+
queue,
475+
executor,
476+
task_id,
477+
task_display_name_pattern,
478+
version_number,
479+
readable_ti_filter,
480+
],
490481
order_by=order_by,
491482
offset=offset,
492483
limit=limit,

0 commit comments

Comments
 (0)