26
26
import pytest
27
27
from uuid6 import uuid7
28
28
29
+ from airflow .exceptions import AirflowSkipException
29
30
from airflow .sdk import DAG , BaseOperator
30
31
from airflow .sdk .api .datamodels ._generated import TaskInstance , TerminalTIState
31
32
from airflow .sdk .execution_time .comms import DeferTask , SetRenderedFields , StartupDetails , TaskState
32
- from airflow .sdk .execution_time .task_runner import CommsDecoder , parse , run , startup
33
+ from airflow .sdk .execution_time .task_runner import CommsDecoder , RuntimeTaskInstance , parse , run , startup
33
34
from airflow .utils import timezone
34
35
35
36
37
+ def get_inline_dag (dag_id : str , task : BaseOperator ) -> DAG :
38
+ """Creates an inline dag and returns it based on dag_id and task."""
39
+ dag = DAG (dag_id = dag_id , start_date = timezone .datetime (2024 , 12 , 3 ))
40
+ task .dag = dag
41
+
42
+ return dag
43
+
44
+
45
+ @pytest .fixture
46
+ def mocked_parse (spy_agency ):
47
+ """
48
+ Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you
49
+ want to isolate and test `parse` or `run` logic without having to define a DAG file.
50
+
51
+ This fixture returns a helper function `set_dag` that:
52
+ 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task)
53
+ 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task.
54
+ 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`.
55
+
56
+ After adding the fixture in your test function signature, you can use it like this ::
57
+
58
+ mocked_parse(
59
+ StartupDetails(
60
+ ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
61
+ file="",
62
+ requests_fd=0,
63
+ ),
64
+ "example_dag_id",
65
+ CustomOperator(task_id="hello"),
66
+ )
67
+ """
68
+
69
+ def set_dag (what : StartupDetails , dag_id : str , task : BaseOperator ) -> RuntimeTaskInstance :
70
+ dag = get_inline_dag (dag_id , task )
71
+ t = dag .task_dict [task .task_id ]
72
+ ti = RuntimeTaskInstance .model_construct (** what .ti .model_dump (exclude_unset = True ), task = t )
73
+ spy_agency .spy_on (parse , call_fake = lambda _ : ti )
74
+ return ti
75
+
76
+ return set_dag
77
+
78
+
79
+ class CustomOperator (BaseOperator ):
80
+ def execute (self , context ):
81
+ task_id = context ["task_instance" ].task_id
82
+ print (f"Hello World { task_id } !" )
83
+
84
+
36
85
class TestCommsDecoder :
37
86
"""Test the communication between the subprocess and the "supervisor"."""
38
87
@@ -64,6 +113,7 @@ def test_recv_StartupDetails(self):
64
113
65
114
66
115
def test_parse (test_dags_dir : Path ):
116
+ """Test that checks parsing of a basic dag with an un-mocked parse."""
67
117
what = StartupDetails (
68
118
ti = TaskInstance (id = uuid7 (), task_id = "a" , dag_id = "super_basic" , run_id = "c" , try_number = 1 ),
69
119
file = str (test_dags_dir / "super_basic.py" ),
@@ -78,42 +128,48 @@ def test_parse(test_dags_dir: Path):
78
128
assert isinstance (ti .task .dag , DAG )
79
129
80
130
81
- def test_run_basic (test_dags_dir : Path , time_machine ):
131
+ def test_run_basic (time_machine , mocked_parse ):
82
132
"""Test running a basic task."""
83
133
what = StartupDetails (
84
134
ti = TaskInstance (id = uuid7 (), task_id = "hello" , dag_id = "super_basic_run" , run_id = "c" , try_number = 1 ),
85
- file = str ( test_dags_dir / "super_basic_run.py" ) ,
135
+ file = "" ,
86
136
requests_fd = 0 ,
87
137
)
88
138
89
- ti = parse (what )
90
-
91
139
instant = timezone .datetime (2024 , 12 , 3 , 10 , 0 )
92
140
time_machine .move_to (instant , tick = False )
93
141
94
142
with mock .patch (
95
143
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS" , create = True
96
144
) as mock_supervisor_comms :
145
+ ti = mocked_parse (what , "super_basic_run" , CustomOperator (task_id = "hello" ))
97
146
run (ti , log = mock .MagicMock ())
98
147
99
148
mock_supervisor_comms .send_request .assert_called_once_with (
100
149
msg = TaskState (state = TerminalTIState .SUCCESS , end_date = instant ), log = mock .ANY
101
150
)
102
151
103
152
104
- def test_run_deferred_basic (test_dags_dir : Path , time_machine ):
153
+ def test_run_deferred_basic (time_machine , mocked_parse ):
105
154
"""Test that a task can transition to a deferred state."""
106
- what = StartupDetails (
107
- ti = TaskInstance (
108
- id = uuid7 (), task_id = "async" , dag_id = "super_basic_deferred_run" , run_id = "c" , try_number = 1
109
- ),
110
- file = str (test_dags_dir / "super_basic_deferred_run.py" ),
111
- requests_fd = 0 ,
112
- )
155
+ import datetime
156
+
157
+ from airflow .providers .standard .sensors .date_time import DateTimeSensorAsync
113
158
114
159
# Use the time machine to set the current time
115
160
instant = timezone .datetime (2024 , 11 , 22 )
161
+ task = DateTimeSensorAsync (
162
+ task_id = "async" ,
163
+ target_time = str (instant + datetime .timedelta (seconds = 3 )),
164
+ poke_interval = 60 ,
165
+ timeout = 600 ,
166
+ )
116
167
time_machine .move_to (instant , tick = False )
168
+ what = StartupDetails (
169
+ ti = TaskInstance (id = uuid7 (), task_id = "async" , dag_id = "basic_deferred_run" , run_id = "c" , try_number = 1 ),
170
+ file = "" ,
171
+ requests_fd = 0 ,
172
+ )
117
173
118
174
# Expected DeferTask
119
175
expected_defer_task = DeferTask (
@@ -131,22 +187,31 @@ def test_run_deferred_basic(test_dags_dir: Path, time_machine):
131
187
with mock .patch (
132
188
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS" , create = True
133
189
) as mock_supervisor_comms :
134
- ti = parse (what )
190
+ ti = mocked_parse (what , "basic_deferred_run" , task )
135
191
run (ti , log = mock .MagicMock ())
136
192
137
193
# send_request will only be called when the TaskDeferred exception is raised
138
194
mock_supervisor_comms .send_request .assert_called_once_with (msg = expected_defer_task , log = mock .ANY )
139
195
140
196
141
- def test_run_basic_skipped (test_dags_dir : Path , time_machine ):
197
+ def test_run_basic_skipped (time_machine , mocked_parse ):
142
198
"""Test running a basic task that marks itself skipped."""
199
+ from airflow .providers .standard .operators .python import PythonOperator
200
+
201
+ task = PythonOperator (
202
+ task_id = "skip" ,
203
+ python_callable = lambda : (_ for _ in ()).throw (
204
+ AirflowSkipException ("This task is being skipped intentionally." ),
205
+ ),
206
+ )
207
+
143
208
what = StartupDetails (
144
209
ti = TaskInstance (id = uuid7 (), task_id = "skip" , dag_id = "basic_skipped" , run_id = "c" , try_number = 1 ),
145
- file = str ( test_dags_dir / "basic_skipped.py" ) ,
210
+ file = "" ,
146
211
requests_fd = 0 ,
147
212
)
148
213
149
- ti = parse (what )
214
+ ti = mocked_parse (what , "basic_skipped" , task )
150
215
151
216
instant = timezone .datetime (2024 , 12 , 3 , 10 , 0 )
152
217
time_machine .move_to (instant , tick = False )
@@ -161,14 +226,23 @@ def test_run_basic_skipped(test_dags_dir: Path, time_machine):
161
226
)
162
227
163
228
164
- def test_startup_basic_templated_dag (test_dags_dir : Path ):
165
- """Test running a basic task."""
229
+ def test_startup_basic_templated_dag (mocked_parse ):
230
+ """Test running a DAG with templated task."""
231
+ from airflow .providers .standard .operators .bash import BashOperator
232
+
233
+ task = BashOperator (
234
+ task_id = "templated_task" ,
235
+ bash_command = "echo 'Logical date is {{ logical_date }}'" ,
236
+ )
237
+
166
238
what = StartupDetails (
167
- ti = TaskInstance (id = uuid7 (), task_id = "task1" , dag_id = "basic_templated_dag" , run_id = "c" , try_number = 1 ),
168
- file = str (test_dags_dir / "basic_templated_dag.py" ),
239
+ ti = TaskInstance (
240
+ id = uuid7 (), task_id = "templated_task" , dag_id = "basic_templated_dag" , run_id = "c" , try_number = 1
241
+ ),
242
+ file = "" ,
169
243
requests_fd = 0 ,
170
244
)
171
- parse (what )
245
+ mocked_parse (what , "basic_templated_dag" , task )
172
246
173
247
with mock .patch (
174
248
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS" , create = True
0 commit comments