Skip to content

Commit

Permalink
Avoid imports from "providers" (apache#46801)
Browse files Browse the repository at this point in the history
  • Loading branch information
potiuk authored Feb 16, 2025
1 parent 4d5846f commit 4e17ecd
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 57 deletions.
14 changes: 14 additions & 0 deletions contributing-docs/11_provider_packages.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ PROVIDER is the name of the provider package. It might be single directory (goog
cases we have a nested structure one level down (``apache/cassandra``, ``apache/druid``, ``microsoft/winrm``,
``common.io`` for example).

What are the pyproject.toml and provider.yaml files
---------------------------------------------------

On top of the standard ``pyproject.toml`` file where we keep project information,
we have a ``provider.yaml`` file in the provider's module of the ``providers`` directory.

Expand All @@ -92,6 +95,9 @@ not modify it - except updating dependencies, as your changes will be lost.
Eventually we might migrate ``provider.yaml`` fully to ``pyproject.toml`` file but it would require custom
``tool.airflow`` toml section to be added to the ``pyproject.toml`` file.

How to manage provider's dependencies
-------------------------------------

If you want to add dependencies to the provider, you should add them to the corresponding ``pyproject.toml``
file.

Expand All @@ -115,6 +121,14 @@ package might be installed when breeze is restarted or by your IDE or by running
or when you run ``pip install -e "./providers"`` or ``pip install -e "./providers/<PROVIDER>"`` for the new
provider structure.

How to reuse code between tests in different providers
------------------------------------------------------

When you develop providers, you might want to reuse some of the code between tests in different providers.
This is possible by placing the code in ``test_utils`` in the ``tests_common`` directory. The ``tests_common``
module is automatically available in the ``sys.path`` when running tests for the providers and you can
import common code from there.

Chicken-egg providers
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

import pytest

import providers.microsoft.azure.tests.unit.microsoft.azure.test_utils
from airflow.exceptions import AirflowException
from airflow.providers.apache.hive.transfers.s3_to_hive import S3ToHiveOperator, uncompress_file

import tests_common.test_utils.file_loading

boto3 = pytest.importorskip("boto3")
moto = pytest.importorskip("moto")
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -218,10 +219,10 @@ def test_execute(self, mock_hiveclihook):

# Upload the file into the Mocked S3 bucket
conn.upload_file(ip_fn, "bucket", self.s3_key + ext)

# file parameter to HiveCliHook.load_file is compared
# against expected file output
providers.microsoft.azure.tests.unit.microsoft.azure.test_utils.load_file.side_effect = (

tests_common.test_utils.file_loading.load_file_from_resources.side_effect = (
lambda *args, **kwargs: self._load_file_side_effect(args, op_fn, ext)
)
# Execute S3ToHiveTransfer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from alembic.migration import MigrationContext
from sqlalchemy import MetaData

import providers.fab.src.airflow.providers.fab as provider_fab
import airflow.providers.fab as provider_fab
from airflow.settings import engine
from airflow.utils.db import (
compare_server_default,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import asyncio
import inspect
from json import JSONDecodeError
from os.path import dirname
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch

Expand All @@ -44,13 +45,12 @@
)
from unit.microsoft.azure.test_utils import (
get_airflow_connection,
load_file,
load_json,
mock_connection,
mock_json_response,
mock_response,
)

from tests_common.test_utils.file_loading import load_file_from_resources, load_json_from_resources
from tests_common.test_utils.providers import get_provider_min_airflow_version

if TYPE_CHECKING:
Expand Down Expand Up @@ -313,7 +313,7 @@ async def test_throw_failed_responses_with_application_json_content_type(self):

class TestResponseHandler:
def test_default_response_handler_when_json(self):
users = load_json("resources", "users.json")
users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
response = mock_json_response(200, users)

actual = asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
Expand All @@ -329,7 +329,7 @@ def test_default_response_handler_when_not_json(self):
assert actual == {}

def test_default_response_handler_when_content(self):
users = load_file("resources", "users.json").encode()
users = load_file_from_resources(dirname(__file__), "..", "resources", "users.json").encode()
response = mock_response(200, users)

actual = asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import locale
from base64 import b64encode
from os.path import dirname
from typing import TYPE_CHECKING, Any

import pytest
Expand All @@ -27,8 +28,9 @@
from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
from airflow.triggers.base import TriggerEvent
from unit.microsoft.azure.base import Base
from unit.microsoft.azure.test_utils import load_file, load_json, mock_json_response, mock_response
from unit.microsoft.azure.test_utils import mock_json_response, mock_response

from tests_common.test_utils.file_loading import load_file_from_resources, load_json_from_resources
from tests_common.test_utils.mock_context import mock_context
from tests_common.test_utils.operators.run_deferrable import execute_operator
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
Expand All @@ -44,8 +46,8 @@
class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
def test_execute(self):
users = load_json("resources", "users.json")
next_users = load_json("resources", "next_users.json")
users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
next_users = load_json_from_resources(dirname(__file__), "..", "resources", "next_users.json")
response = mock_json_response(200, users, next_users)

with self.patch_hook_and_request_adapter(response):
Expand All @@ -72,7 +74,7 @@ def test_execute(self):

@pytest.mark.db_test
def test_execute_when_do_xcom_push_is_false(self):
users = load_json("resources", "users.json")
users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
users.pop("@odata.nextLink")
response = mock_json_response(200, users)

Expand Down Expand Up @@ -134,7 +136,9 @@ def custom_event_handler(context: Context, event: dict[Any, Any] | None = None):

@pytest.mark.db_test
def test_execute_when_response_is_bytes(self):
content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
content = load_file_from_resources(
dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", encoding=None
)
base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding())
drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
response = mock_response(200, content)
Expand All @@ -161,7 +165,9 @@ def test_execute_when_response_is_bytes(self):
@pytest.mark.db_test
@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters works in Airflow >= 2.10.0")
def test_execute_with_lambda_parameter_when_response_is_bytes(self):
content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
content = load_file_from_resources(
dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", encoding=None
)
base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding())
drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
response = mock_response(200, content)
Expand Down Expand Up @@ -202,7 +208,7 @@ def test_paginate_without_query_parameters(self):
url="users",
)
context = mock_context(task=operator)
response = load_json("resources", "users.json")
response = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
next_link, query_parameters = MSGraphAsyncOperator.paginate(operator, response, context)

assert next_link == response["@odata.nextLink"]
Expand All @@ -216,7 +222,7 @@ def test_paginate_with_context_query_parameters(self):
query_parameters={"$top": 12},
)
context = mock_context(task=operator)
response = load_json("resources", "users.json")
response = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
response["@odata.count"] = 100
url, query_parameters = MSGraphAsyncOperator.paginate(operator, response, context)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@

import json
from datetime import datetime
from os.path import dirname

import pytest

from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
from airflow.triggers.base import TriggerEvent
from unit.microsoft.azure.base import Base
from unit.microsoft.azure.test_utils import load_json, mock_json_response
from unit.microsoft.azure.test_utils import mock_json_response

from tests_common.test_utils.file_loading import load_json_from_resources
from tests_common.test_utils.operators.run_deferrable import execute_operator
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS


class TestMSGraphSensor(Base):
def test_execute(self):
status = load_json("resources", "status.json")
status = load_json_from_resources(dirname(__file__), "..", "resources", "status.json")
response = mock_json_response(200, *status)

with self.patch_hook_and_request_adapter(response):
Expand Down Expand Up @@ -65,7 +67,7 @@ def test_execute(self):

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters works in Airflow >= 2.10.0")
def test_execute_with_lambda_parameter(self):
status = load_json("resources", "status.json")
status = load_json_from_resources(dirname(__file__), "..", "resources", "status.json")
response = mock_json_response(200, *status)

with self.patch_hook_and_request_adapter(response):
Expand Down
27 changes: 0 additions & 27 deletions providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@

from __future__ import annotations

import json
import re
from json import JSONDecodeError
from os.path import dirname, join
from typing import Any
from unittest import mock
from unittest.mock import MagicMock
Expand Down Expand Up @@ -237,27 +234,3 @@ def mock_response(status_code, content: Any = None, headers: dict | None = None)
response.content = content
response.json.side_effect = JSONDecodeError("", "", 0)
return response


def remove_license_header(content: str) -> str:
"""Remove license header from the given content."""
# Define the pattern to match both block and single-line comments
pattern = r"(/\*.*?\*/)|(--.*?(\r?\n|\r))|(#.*?(\r?\n|\r))"

# Check if there is a license header at the beginning of the file
if re.match(pattern, content, flags=re.DOTALL):
# Use re.DOTALL to allow .* to match newline characters in block comments
return re.sub(pattern, "", content, flags=re.DOTALL).strip()
return content.strip()


def load_json(*args: str):
with open(join(dirname(__file__), *args), encoding="utf-8") as file:
return json.load(file)


def load_file(*args: str, mode="r", encoding="utf-8"):
with open(join(dirname(__file__), *args), mode=mode, encoding=encoding) as file:
if mode == "r":
return remove_license_header(file.read())
return file.read()
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import locale
from base64 import b64decode, b64encode
from datetime import datetime
from os.path import dirname
from unittest.mock import patch
from uuid import uuid4

Expand All @@ -36,18 +37,17 @@
from unit.microsoft.azure.base import Base
from unit.microsoft.azure.test_utils import (
get_airflow_connection,
load_file,
load_json,
mock_json_response,
mock_response,
)

from tests_common.test_utils.file_loading import load_file_from_resources, load_json_from_resources
from tests_common.test_utils.operators.run_deferrable import run_trigger


class TestMSGraphTrigger(Base):
def test_run_when_valid_response(self):
users = load_json("resources", "users.json")
users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
response = mock_json_response(200, users)

with self.patch_hook_and_request_adapter(response):
Expand Down Expand Up @@ -83,7 +83,9 @@ def test_run_when_response_cannot_be_converted_to_json(self):
assert actual.payload["message"] == ""

def test_run_when_response_is_bytes(self):
content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
content = load_file_from_resources(
dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", encoding=None
)
base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding())
response = mock_response(200, content)

Expand Down Expand Up @@ -138,7 +140,9 @@ def test_template_fields(self):

class TestResponseSerializer:
def test_serialize_when_bytes_then_base64_encoded(self):
response = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
response = load_file_from_resources(
dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", encoding=None
)
content = b64encode(response).decode(locale.getpreferredencoding())

actual = ResponseSerializer().serialize(response)
Expand All @@ -163,15 +167,17 @@ def test_serialize_when_dict_with_uuid_datatime_and_pendulum_then_json(self):
)

def test_deserialize_when_json(self):
response = load_file("resources", "users.json")
response = load_file_from_resources(dirname(__file__), "..", "resources", "users.json")

actual = ResponseSerializer().deserialize(response)

assert isinstance(actual, dict)
assert actual == load_json("resources", "users.json")
assert actual == load_json_from_resources(dirname(__file__), "..", "resources", "users.json")

def test_deserialize_when_base64_encoded_string(self):
content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
content = load_file_from_resources(
dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", encoding=None
)
response = b64encode(content).decode(locale.getpreferredencoding())

actual = ResponseSerializer().deserialize(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from os.path import dirname
from unittest import mock

import pytest
Expand All @@ -25,7 +26,8 @@
from airflow.configuration import conf
from airflow.models import Connection
from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
from unit.microsoft.mssql.test_utils import load_file

from tests_common.test_utils.file_loading import load_file_from_resources

try:
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
Expand Down Expand Up @@ -286,7 +288,7 @@ def test_generate_insert_sql(self, get_connection):
],
replace=True,
)
assert sql == load_file("resources", "replace.sql")
assert sql == load_file_from_resources(dirname(__file__), "..", "resources", "replace.sql")

def test_dialect_name(self):
hook = MsSqlHook()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ banned-module-level-imports = ["numpy", "pandas"]
"sqlalchemy.ext.declarative.as_declarative".msg = "Use `sqlalchemy.orm.as_declarative`. Moved in SQLAlchemy 2.0"
"sqlalchemy.ext.declarative.has_inherited_table".msg = "Use `sqlalchemy.orm.has_inherited_table`. Moved in SQLAlchemy 2.0"
"sqlalchemy.ext.declarative.synonym_for".msg = "Use `sqlalchemy.orm.synonym_for`. Moved in SQLAlchemy 2.0"
"providers".msg = "You should not import 'providers' as a Python module. Imports in providers should be done starting from 'src' or 'tests' folders, for example 'from airflow.providers.airbyte' or 'from unit.airbyte' or 'from system.airbyte'"

[tool.ruff.lint.flake8-type-checking]
exempt-modules = ["typing", "typing_extensions"]
Expand Down Expand Up @@ -542,7 +543,6 @@ python_files = [
testpaths = [
"tests",
]

asyncio_default_fixture_loop_scope = "function"

# Keep temporary directories (created by `tmp_path`) for 2 recent runs only failed tests.
Expand Down
Loading

0 comments on commit 4e17ecd

Please sign in to comment.