Skip to content

Commit

Permalink
AIP-72: Support DAG parsing context in Task SDK (apache#45694)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored and HariGS-DB committed Jan 16, 2025
1 parent 8d6f2c4 commit e155d30
Show file tree
Hide file tree
Showing 15 changed files with 214 additions and 53 deletions.
2 changes: 1 addition & 1 deletion airflow/cli/commands/remote_commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
from airflow.jobs.job import Job
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.cli import get_dag, process_subdir, suppress_logs_and_warning
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
from airflow.utils.helpers import ask_yesno
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
Expand Down
2 changes: 1 addition & 1 deletion airflow/task/standard_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
from airflow.models.taskinstance import TaskReturnCode
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.settings import CAN_FORK
from airflow.stats import Stats
from airflow.utils.configuration import tmp_configuration_copy
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.platform import IS_WINDOWS, getuser
Expand Down
49 changes: 11 additions & 38 deletions airflow/utils/dag_parsing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from contextlib import contextmanager
from typing import NamedTuple


class AirflowParsingContext(NamedTuple):
"""
Context of parsing for the DAG.
If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to
execute. You can use these for optimizing dynamically generated DAG files.
"""

dag_id: str | None
task_id: str | None

from __future__ import annotations

_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"
import warnings

from airflow.sdk.definitions.context import get_parsing_context

@contextmanager
def _airflow_parsing_context_manager(dag_id: str | None = None, task_id: str | None = None):
old_dag_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID)
old_task_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID)
if dag_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = dag_id
if task_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = task_id
yield
if old_task_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = old_task_id
if old_dag_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = old_dag_id
# TODO: Remove this module in Airflow 3.2

warnings.warn(
"Import from the airflow.utils.dag_parsing_context module is deprecated and "
"will be removed in Airflow 3.2. Please import it from 'airflow.sdk'.",
DeprecationWarning,
stacklevel=2,
)

def get_parsing_context() -> AirflowParsingContext:
"""Return the current (DAG) parsing context info."""
return AirflowParsingContext(
dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID),
task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID),
)
__all__ = ["get_parsing_context"]
2 changes: 1 addition & 1 deletion docs/apache-airflow/howto/dynamic-dag-generation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ of the context are set to ``None``.
:emphasize-lines: 4,8,9
from airflow.models.dag import DAG
from airflow.utils.dag_parsing_context import get_parsing_context
from airflow.sdk import get_parsing_context
current_dag_id = get_parsing_context().dag_id
Expand Down
51 changes: 51 additions & 0 deletions newsfragments/45694.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
``get_parsing_context`` have been moved to Task SDK

As part of AIP-72: Task SDK, the function ``get_parsing_context`` has been moved to ``airflow.sdk`` module.
Previously, it was located in ``airflow.utils.dag_parsing_context`` module.

This function is used to optimize DAG parsing during execution when DAGs are generated dynamically.

Before:

.. code-block:: python
from airflow.models.dag import DAG
from airflow.utils.dag_parsing_context import get_parsing_context
current_dag_id = get_parsing_context().dag_id
for thing in list_of_things:
dag_id = f"generated_dag_{thing}"
if current_dag_id is not None and current_dag_id != dag_id:
continue # skip generation of non-selected DAG
with DAG(dag_id=dag_id, ...):
...
After:

.. code-block:: python
from airflow.sdk import get_parsing_context
current_dag_id = get_parsing_context().dag_id
# The rest of the code remains the same
* Types of change

* [x] DAG changes
* [ ] Config changes
* [ ] API changes
* [ ] CLI changes
* [ ] Behaviour changes
* [ ] Plugin changes
* [ ] Dependency change

* Migration rules needed

* ruff

* AIR302

* [ ] ``airflow.utils.dag_parsing_context.get_parsing_context`` -> ``airflow.sdk.get_parsing_context``
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.stats import Stats
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
Expand Down
1 change: 0 additions & 1 deletion scripts/cov/core_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
"airflow/utils/code_utils.py",
"airflow/utils/context.py",
"airflow/utils/dag_cycle_tester.py",
"airflow/utils/dag_parsing_context.py",
"airflow/utils/dates.py",
"airflow/utils/db.py",
"airflow/utils/db_cleanup.py",
Expand Down
4 changes: 3 additions & 1 deletion task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"dag",
"Connection",
"get_current_context",
"get_parsing_context",
"__version__",
]

Expand All @@ -35,7 +36,7 @@
if TYPE_CHECKING:
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import get_current_context
from airflow.sdk.definitions.context import get_current_context, get_parsing_context
from airflow.sdk.definitions.dag import DAG, dag
from airflow.sdk.definitions.edges import EdgeModifier, Label
from airflow.sdk.definitions.taskgroup import TaskGroup
Expand All @@ -50,6 +51,7 @@
"Connection": ".definitions.connection",
"Variable": ".definitions.variable",
"get_current_context": ".definitions.context",
"get_parsing_context": ".definitions.context",
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from contextlib import contextmanager

from airflow.sdk.definitions.context import _AIRFLOW_PARSING_CONTEXT_DAG_ID, _AIRFLOW_PARSING_CONTEXT_TASK_ID


@contextmanager
def _airflow_parsing_context_manager(dag_id: str | None = None, task_id: str | None = None):
old_dag_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID)
old_task_id = os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID)
if dag_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = dag_id
if task_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = task_id
yield
if old_task_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = old_task_id
if old_dag_id is not None:
os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = old_dag_id
27 changes: 26 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypedDict
import os
from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict

if TYPE_CHECKING:
# TODO: Should we use pendulum.DateTime instead of datetime like AF 2.x?
Expand Down Expand Up @@ -105,3 +106,27 @@ def my_task():
from airflow.sdk.definitions._internal.contextmanager import _get_current_context

return _get_current_context()


class AirflowParsingContext(NamedTuple):
"""
Context of parsing for the DAG.
If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to
execute. You can use these for optimizing dynamically generated DAG files.
"""

dag_id: str | None
task_id: str | None


_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"


def get_parsing_context() -> AirflowParsingContext:
"""Return the current (DAG) parsing context info."""
return AirflowParsingContext(
dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID),
task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID),
)
5 changes: 3 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
Expand Down Expand Up @@ -406,8 +407,8 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
setproctitle(f"airflow worker -- {msg.ti.id}")

log = structlog.get_logger(logger_name="task")
# TODO: set the "magic loop" context vars for parsing
ti = parse(msg)
with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
ti = parse(msg)
log.debug("DAG file parsed", file=msg.dag_rel_path)
else:
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
Expand Down
36 changes: 36 additions & 0 deletions task_sdk/tests/dags/dag_parsing_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime

from airflow.sdk import DAG, BaseOperator, get_parsing_context

DAG_ID = "dag_parsing_context_test"

current_dag_id = get_parsing_context().dag_id

with DAG(
DAG_ID,
start_date=datetime(2024, 2, 21),
schedule=None,
) as the_dag:
BaseOperator(task_id="visible_task")

if current_dag_id == DAG_ID:
# this task will be invisible if the DAG ID is not properly set in the parsing context.
BaseOperator(task_id="conditional_task")
40 changes: 40 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,46 @@ def execute(self, context):
)


def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch, test_dags_dir):
"""
Test that the DAG parsing context is correctly set during the startup process.
This test verifies that the DAG and task IDs are correctly set in the parsing context
when a DAG is started up.
"""
dag_id = "dag_parsing_context_test"
task_id = "conditional_task"

what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
dag_rel_path="dag_parsing_context.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
requests_fd=0,
ti_context=make_ti_context(dag_id=dag_id, run_id="c"),
)

mock_supervisor_comms.get_message.return_value = what

# Set the environment variable for DAG bundles
# We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test!
dag_bundle_val = json.dumps(
[
{
"name": "my-bundle",
"classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
"kwargs": {"local_folder": str(test_dags_dir), "refresh_interval": 1},
}
]
)

monkeypatch.setenv("AIRFLOW__DAG_BUNDLES__BACKENDS", dag_bundle_val)
ti, _ = startup()

# Presence of `conditional_task` below means DAG ID is properly set in the parsing context!
# Check the dag file for the actual logic!
assert ti.task.dag.task_dict.keys() == {"visible_task", "conditional_task"}


class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
"""Test get_template_context without ti_context_from_server."""
Expand Down
2 changes: 1 addition & 1 deletion tests/dags/test_dag_parsing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.dag_parsing_context import get_parsing_context
from airflow.sdk.definitions.context import get_parsing_context

DAG_ID = "test_dag_parsing_context"

Expand Down
7 changes: 2 additions & 5 deletions tests/dags/test_parsing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.dag_parsing_context import (
from airflow.sdk.definitions.context import (
_AIRFLOW_PARSING_CONTEXT_DAG_ID,
_AIRFLOW_PARSING_CONTEXT_TASK_ID,
Context,
)
from airflow.utils.timezone import datetime

if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context


class DagWithParsingContext(EmptyOperator):
def execute(self, context: Context):
Expand Down

0 comments on commit e155d30

Please sign in to comment.