diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index fae5ae6eb3091..31e1924c18ad8 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json import os import re import subprocess @@ -58,6 +59,9 @@ class BaseK8STest: @pytest.fixture(autouse=True) def base_tests_setup(self, request): + self.set_api_server_base_url_config() + self.rollout_restart_deployment("airflow-api-server") + self.ensure_deployment_health("airflow-api-server") # Replacement for unittests.TestCase.id() self.test_id = f"{request.node.cls.__name__}_{request.node.name}" self.session = self._get_session_with_retries() @@ -204,6 +208,63 @@ def ensure_deployment_health(deployment_name: str, namespace: str = "airflow"): ).decode() assert "successfully rolled out" in deployment_rollout_status + @staticmethod + def rollout_restart_deployment(deployment_name: str, namespace: str = "airflow"): + """Rollout restart the deployment.""" + check_call(["kubectl", "rollout", "restart", "deployment", deployment_name, "-n", namespace]) + + def _parse_airflow_cfg_as_dict(self, airflow_cfg: str) -> dict[str, dict[str, str]]: + """Parse the airflow.cfg file as a dictionary.""" + parsed_airflow_cfg: dict[str, dict[str, str]] = {} + for line in airflow_cfg.splitlines(): + if line.startswith("["): + section = line[1:-1] + parsed_airflow_cfg[section] = {} + elif "=" in line: + key, value = line.split("=", 1) + parsed_airflow_cfg[section][key.strip()] = value.strip() + return parsed_airflow_cfg + + def _parse_airflow_cfg_dict_as_escaped_toml(self, airflow_cfg_dict: dict) -> str: + """Parse the airflow.cfg dictionary as a toml string.""" + airflow_cfg_str = "" + for section, section_dict in airflow_cfg_dict.items(): + airflow_cfg_str += f"[{section}]\n" + for key, value in section_dict.items(): + airflow_cfg_str += f"{key} = {value}\n" + airflow_cfg_str += "\n" + # escape newlines and double quotes + return airflow_cfg_str.replace("\n", "\\n").replace('"', '\\"') + + def set_api_server_base_url_config(self): + """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap.""" + configmap_name = "airflow-config" + configmap_key = "airflow.cfg" + original_configmap_json_str = check_output( + ["kubectl", "get", "configmap", configmap_name, "-n", "airflow", "-o", "json"] + ).decode() + original_config_map = json.loads(original_configmap_json_str) + original_airflow_cfg = original_config_map["data"][configmap_key] + # set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` in airflow.cfg + # The airflow.cfg is toml format, so we need to convert it to json + airflow_cfg_dict = self._parse_airflow_cfg_as_dict(original_airflow_cfg) + airflow_cfg_dict["api"]["base_url"] = f"http://{KUBERNETES_HOST_PORT}" + # update the configmap with the new airflow.cfg + check_call( + [ + "kubectl", + "patch", + "configmap", + configmap_name, + "-n", + "airflow", + "--type", + "merge", + "-p", + f'{{"data": {{"{configmap_key}": "{self._parse_airflow_cfg_dict_as_escaped_toml(airflow_cfg_dict)}"}}}}', + ] + ) + def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout): tries = 0 state = ""