diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 8c8fac4de302..4c4a0bed39b9 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -39,7 +39,7 @@ @functools.lru_cache -def _initialize_map() -> dict[str, Callable]: +def initialize_method_map() -> dict[str, Callable]: from airflow.cli.commands.task_command import _get_ti_db_access from airflow.dag_processing.manager import DagFileProcessorManager from airflow.dag_processing.processor import DagFileProcessor @@ -147,7 +147,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse: if json_rpc != "2.0": return log_and_build_error_response(message="Expected jsonrpc 2.0 request.", status=400) - methods_map = _initialize_map() + methods_map = initialize_method_map() method_name = body.get("method") if method_name not in methods_map: return log_and_build_error_response(message=f"Unrecognized method: {method_name}.", status=400) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index fc2314578fbd..2da451c15537 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -22,6 +22,7 @@ import logging from functools import wraps from typing import Callable, TypeVar +from urllib.parse import urlparse import requests import tenacity @@ -58,6 +59,18 @@ def force_database_direct_access(message: str): if _ENABLE_AIP_44: logger.info("Forcing database direct access. %s", message) + @staticmethod + def force_api_access(api_endpoint: str): + """ + Force using Internal API with provided endpoint. + + All methods decorated with internal_api_call will always be executed remote/via API. + This mode is needed for remote setups/remote executor. + """ + InternalApiConfig._initialized = True + InternalApiConfig._use_internal_api = True + InternalApiConfig._internal_api_endpoint = api_endpoint + @staticmethod def get_use_internal_api(): if not InternalApiConfig._initialized: @@ -77,10 +90,14 @@ def _init_values(): raise RuntimeError("The AIP_44 is not enabled so you cannot use it.") internal_api_endpoint = "" if use_internal_api: - internal_api_url = conf.get("core", "internal_api_url") - internal_api_endpoint = internal_api_url + "/internal_api/v1/rpcapi" - if not internal_api_endpoint.startswith("http://"): - raise AirflowConfigException("[core]internal_api_url must start with http://") + url_conf = urlparse(conf.get("core", "internal_api_url")) + api_path = url_conf.path + if api_path in ["", "/"]: + # Add the default path if not given in the configuration + api_path = "/internal_api/v1/rpcapi" + if url_conf.scheme not in ["http", "https"]: + raise AirflowConfigException("[core]internal_api_url must start with http:// or https://") + internal_api_endpoint = f"{url_conf.scheme}://{url_conf.netloc}{api_path}" InternalApiConfig._initialized = True InternalApiConfig._use_internal_api = use_internal_api diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 42027f981c80..426ba061a2d8 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -422,7 +422,7 @@ def deref(self, dag: DAG) -> ExpandInput: LogTemplate: LogTemplatePydantic, Dataset: DatasetPydantic, } -_type_to_class: dict[DAT, list] = { +_type_to_class: dict[DAT | str, list] = { DAT.BASE_JOB: [JobPydantic, Job], DAT.TASK_INSTANCE: [TaskInstancePydantic, TaskInstance], DAT.DAG_RUN: [DagRunPydantic, DagRun], @@ -433,6 +433,13 @@ def deref(self, dag: DAG) -> ExpandInput: _class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for cls_ in classes} +def add_pydantic_class_type_mapping(attribute_type: str, orm_class, pydantic_class): + _orm_to_model[orm_class] = pydantic_class + _type_to_class[attribute_type] = [pydantic_class, orm_class] + _class_to_type[pydantic_class] = attribute_type + _class_to_type[orm_class] = attribute_type + + class BaseSerialization: """BaseSerialization provides utils for serialization.""" diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py index 4c312da3a708..64ea733d39c5 100644 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -75,12 +75,12 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: mock_test_method.reset_mock() mock_test_method.side_effect = None with mock.patch( - "airflow.api_internal.endpoints.rpc_api_endpoint._initialize_map" - ) as mock_initialize_map: - mock_initialize_map.return_value = { + "airflow.api_internal.endpoints.rpc_api_endpoint.initialize_method_map" + ) as mock_initialize_method_map: + mock_initialize_method_map.return_value = { TEST_METHOD_NAME: mock_test_method, } - yield mock_initialize_map + yield mock_initialize_method_map @pytest.mark.parametrize( "input_params, method_result, result_cmp_func, method_params",