17
17
18
18
from __future__ import annotations
19
19
20
- from typing import Annotated , Literal , cast
20
+ from collections .abc import Sequence
21
+ from typing import Annotated , Any , Literal , cast
21
22
22
23
import structlog
23
24
from fastapi import Depends , HTTPException , Query , status
49
50
float_range_filter_factory ,
50
51
)
51
52
from airflow .api_fastapi .common .router import AirflowRouter
53
+ from airflow .api_fastapi .core_api .base import OrmClause
52
54
from airflow .api_fastapi .core_api .datamodels .common import BulkBody , BulkResponse
53
55
from airflow .api_fastapi .core_api .datamodels .task_instances import (
54
56
BulkTaskInstanceBody ,
@@ -194,9 +196,10 @@ def get_mapped_task_instances(
194
196
error_message = f"Task id { task_id } is not mapped"
195
197
raise HTTPException (status .HTTP_404_NOT_FOUND , error_message )
196
198
197
- task_instance_select , total_entries = paginated_select (
198
- statement = query ,
199
- filters = [
199
+ if state .value is not None and None in state .value :
200
+ filters = [state , pool , queue , executor , version_number ]
201
+ else :
202
+ filters = [
200
203
run_after_range ,
201
204
logical_date_range ,
202
205
start_date_range ,
@@ -208,7 +211,10 @@ def get_mapped_task_instances(
208
211
queue ,
209
212
executor ,
210
213
version_number ,
211
- ],
214
+ ]
215
+ task_instance_select , total_entries = paginated_select (
216
+ statement = query ,
217
+ filters = cast ("Sequence[OrmClause[Any]]" , filters ),
212
218
order_by = order_by ,
213
219
offset = offset ,
214
220
limit = limit ,
@@ -460,9 +466,19 @@ def get_task_instances(
460
466
)
461
467
query = query .where (TI .run_id == dag_run_id )
462
468
463
- task_instance_select , total_entries = paginated_select (
464
- statement = query ,
465
- filters = [
469
+ if state .value is not None and None in state .value :
470
+ filters = [
471
+ state ,
472
+ pool ,
473
+ queue ,
474
+ executor ,
475
+ task_id ,
476
+ task_display_name_pattern ,
477
+ version_number ,
478
+ readable_ti_filter ,
479
+ ]
480
+ else :
481
+ filters = [
466
482
run_after_range ,
467
483
logical_date_range ,
468
484
start_date_range ,
@@ -477,7 +493,10 @@ def get_task_instances(
477
493
task_display_name_pattern ,
478
494
version_number ,
479
495
readable_ti_filter ,
480
- ],
496
+ ]
497
+ task_instance_select , total_entries = paginated_select (
498
+ statement = query ,
499
+ filters = cast ("Sequence[OrmClause[Any]]" , filters ),
481
500
order_by = order_by ,
482
501
offset = offset ,
483
502
limit = limit ,
0 commit comments