
Commit b3f3649

AIP-72: Inline DAG injection for task runner tests (apache#44808)
closes: apache#44805
Dependent on apache#44786

Every time we port handling for a different TI state into the task runner, it is usually followed by an integration test of sorts that exercises the end-to-end flow and checks whether that state is reachable. For example:

1. For the skipped state, we use the DAG https://github.com/apache/airflow/pull/44786/files#diff-cabbddd33130ce1a769412f5fc55dd23e4af4d0fa75f8981689daae769e0680dR1 and test it with the task runner unit test: https://github.com/apache/airflow/pull/44786/files#diff-413c3c59636a3c7b41b8bb822827d18a959778d0b6331532e0db175c829dbfd2R141-R161
2. For the deferred state, we use the DAG https://github.com/apache/airflow/pull/44241/files#diff-2152ed5392424771e27a69173b3c18caae717939719df8f5dbbbdfee5f9efd9bR1 and test it with the task runner unit test: https://github.com/apache/airflow/pull/44241/files#diff-413c3c59636a3c7b41b8bb822827d18a959778d0b6331532e0db175c829dbfd2R93-R127

As new TI states (and tests for them) are added, this leads to a huge folder of DAGs under `task_sdk/tests/dags` that could soon become ever-growing and unmanageable.

The solution has two parts:

1. The ability to create dynamic, inline DAGs, implemented with a DAG-factory style function:

   ```
   def get_inline_dag(dag_id: str, tasks: BaseOperator) -> DAG:
       dag = DAG(
           dag_id=dag_id,
           default_args={"start_date": timezone.datetime(2024, 12, 3)},
       )
       setattr(tasks, "dag", dag)

       return dag
   ```

   This function accepts one task for now, creates a DAG from it, and returns the DAG object, which should suffice for our current testing needs. If the need arises, it can be extended to support more than one task and their relationships.

   Usage:

   ```
   task = PythonOperator(
       task_id="skip",
       python_callable=lambda: (_ for _ in ()).throw(
           AirflowSkipException("This task is being skipped intentionally."),
       ),
   )

   dag = get_inline_dag("basic_skipped", task)
   ```

   Usage is as simple as creating a task from any operator and passing it to this function.

2. Mocking the `parse` function using KGB's spy_agency: https://pypi.org/project/kgb/

   The idea is to use a spy agency to substitute the `parse` function with a mock parser that does the bare minimum of what the actual parser does (a minimal illustrative sketch follows this commit message). We chose spy_agency over the `mock` library for two main reasons:

   a) With spy_agency, you can mock specific methods or functions without affecting the entire class or module.
   b) Minimal disruption and ease of use.

Changes in this commit:

1. Replaced all usages of "actual" DAGs with inline DAGs in the task runner tests that either parse or run.
2. Deleted two DAGs.
3. The other two DAGs cannot be removed yet because they are tied to the test_supervisor.py tests, which still use the DAG path; this can be taken up in a follow-up if needed.

Example:

![image](https://github.com/user-attachments/assets/01baa82a-7b43-4ff1-bc7e-c2fc20cef50d)

Benefits:

1. No need to create any more DAG files for task runner integration tests, which would otherwise be frequent at the current development pace of AIP-72.
2. Ability to easily create inline DAGs.

Basic DAG

![image](https://github.com/user-attachments/assets/cf7a94b5-6c4c-4103-99a0-32047207a9b2)

Deferred DAG

![image](https://github.com/user-attachments/assets/328f99d0-4483-48c5-9127-dd7812f47ae0)
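For readers unfamiliar with kgb, here is a minimal, self-contained sketch of the substitution pattern the new `mocked_parse` fixture relies on: `spy_on(..., call_fake=...)` replaces a single function for the duration of one test while leaving the rest of the module untouched. The `load_config` function below is a hypothetical stand-in, not part of the Airflow Task SDK; the `spy_agency` fixture is assumed to come from kgb's pytest plugin, the same fixture the commit's `mocked_parse` fixture uses.

```python
# A minimal sketch of the kgb pattern, not Airflow code.


def load_config(path: str) -> dict:
    # Imagine an expensive, file-system-dependent implementation here.
    raise FileNotFoundError(path)


def test_load_config_is_stubbed(spy_agency):
    # call_fake swaps out only this one function, and only for this test;
    # the spy agency restores the original on teardown.
    spy_agency.spy_on(load_config, call_fake=lambda path: {"source": path})

    assert load_config("dag.py") == {"source": "dag.py"}
```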
1 parent e122b20 commit b3f3649

File tree: 3 files changed, +96 −95 lines


task_sdk/tests/dags/basic_skipped.py (−36 lines): this file was deleted.

task_sdk/tests/dags/basic_templated_dag.py (−37 lines): this file was deleted.

task_sdk/tests/execution_time/test_task_runner.py (+96 −22)
@@ -26,13 +26,62 @@
 import pytest
 from uuid6 import uuid7
 
+from airflow.exceptions import AirflowSkipException
 from airflow.sdk import DAG, BaseOperator
 from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
-from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run, startup
+from airflow.sdk.execution_time.task_runner import CommsDecoder, RuntimeTaskInstance, parse, run, startup
 from airflow.utils import timezone
 
 
+def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG:
+    """Creates an inline dag and returns it based on dag_id and task."""
+    dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
+    task.dag = dag
+
+    return dag
+
+
+@pytest.fixture
+def mocked_parse(spy_agency):
+    """
+    Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you
+    want to isolate and test `parse` or `run` logic without having to define a DAG file.
+
+    This fixture returns a helper function `set_dag` that:
+    1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task)
+    2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task.
+    3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`.
+
+    After adding the fixture in your test function signature, you can use it like this ::
+
+        mocked_parse(
+            StartupDetails(
+                ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
+                file="",
+                requests_fd=0,
+            ),
+            "example_dag_id",
+            CustomOperator(task_id="hello"),
+        )
+    """
+
+    def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance:
+        dag = get_inline_dag(dag_id, task)
+        t = dag.task_dict[task.task_id]
+        ti = RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=t)
+        spy_agency.spy_on(parse, call_fake=lambda _: ti)
+        return ti
+
+    return set_dag
+
+
+class CustomOperator(BaseOperator):
+    def execute(self, context):
+        task_id = context["task_instance"].task_id
+        print(f"Hello World {task_id}!")
+
+
 class TestCommsDecoder:
     """Test the communication between the subprocess and the "supervisor"."""
 
@@ -64,6 +113,7 @@ def test_recv_StartupDetails(self):
 
 
 def test_parse(test_dags_dir: Path):
+    """Test that checks parsing of a basic dag with an un-mocked parse."""
     what = StartupDetails(
         ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1),
         file=str(test_dags_dir / "super_basic.py"),
@@ -78,42 +128,48 @@ def test_parse(test_dags_dir: Path):
     assert isinstance(ti.task.dag, DAG)
 
 
-def test_run_basic(test_dags_dir: Path, time_machine):
+def test_run_basic(time_machine, mocked_parse):
     """Test running a basic task."""
     what = StartupDetails(
         ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
-        file=str(test_dags_dir / "super_basic_run.py"),
+        file="",
         requests_fd=0,
     )
 
-    ti = parse(what)
-
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
 
     with mock.patch(
         "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
     ) as mock_supervisor_comms:
+        ti = mocked_parse(what, "super_basic_run", CustomOperator(task_id="hello"))
         run(ti, log=mock.MagicMock())
 
     mock_supervisor_comms.send_request.assert_called_once_with(
         msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
     )
 
 
-def test_run_deferred_basic(test_dags_dir: Path, time_machine):
+def test_run_deferred_basic(time_machine, mocked_parse):
     """Test that a task can transition to a deferred state."""
-    what = StartupDetails(
-        ti=TaskInstance(
-            id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="c", try_number=1
-        ),
-        file=str(test_dags_dir / "super_basic_deferred_run.py"),
-        requests_fd=0,
-    )
+    import datetime
+
+    from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync
 
     # Use the time machine to set the current time
     instant = timezone.datetime(2024, 11, 22)
+    task = DateTimeSensorAsync(
+        task_id="async",
+        target_time=str(instant + datetime.timedelta(seconds=3)),
+        poke_interval=60,
+        timeout=600,
+    )
     time_machine.move_to(instant, tick=False)
+    what = StartupDetails(
+        ti=TaskInstance(id=uuid7(), task_id="async", dag_id="basic_deferred_run", run_id="c", try_number=1),
+        file="",
+        requests_fd=0,
+    )
 
     # Expected DeferTask
     expected_defer_task = DeferTask(
@@ -131,22 +187,31 @@ def test_run_deferred_basic(test_dags_dir: Path, time_machine):
     with mock.patch(
         "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
     ) as mock_supervisor_comms:
-        ti = parse(what)
+        ti = mocked_parse(what, "basic_deferred_run", task)
         run(ti, log=mock.MagicMock())
 
     # send_request will only be called when the TaskDeferred exception is raised
     mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY)
 
 
-def test_run_basic_skipped(test_dags_dir: Path, time_machine):
+def test_run_basic_skipped(time_machine, mocked_parse):
    """Test running a basic task that marks itself skipped."""
+    from airflow.providers.standard.operators.python import PythonOperator
+
+    task = PythonOperator(
+        task_id="skip",
+        python_callable=lambda: (_ for _ in ()).throw(
+            AirflowSkipException("This task is being skipped intentionally."),
+        ),
+    )
+
     what = StartupDetails(
         ti=TaskInstance(id=uuid7(), task_id="skip", dag_id="basic_skipped", run_id="c", try_number=1),
-        file=str(test_dags_dir / "basic_skipped.py"),
+        file="",
         requests_fd=0,
     )
 
-    ti = parse(what)
+    ti = mocked_parse(what, "basic_skipped", task)
 
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
@@ -161,14 +226,23 @@ def test_run_basic_skipped(test_dags_dir: Path, time_machine):
     )
 
 
-def test_startup_basic_templated_dag(test_dags_dir: Path):
-    """Test running a basic task."""
+def test_startup_basic_templated_dag(mocked_parse):
+    """Test running a DAG with templated task."""
+    from airflow.providers.standard.operators.bash import BashOperator
+
+    task = BashOperator(
+        task_id="templated_task",
+        bash_command="echo 'Logical date is {{ logical_date }}'",
+    )
+
     what = StartupDetails(
-        ti=TaskInstance(id=uuid7(), task_id="task1", dag_id="basic_templated_dag", run_id="c", try_number=1),
-        file=str(test_dags_dir / "basic_templated_dag.py"),
+        ti=TaskInstance(
+            id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
+        ),
+        file="",
         requests_fd=0,
     )
-    parse(what)
+    mocked_parse(what, "basic_templated_dag", task)
 
     with mock.patch(
         "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
