Skip to content

Commit

Permalink
Have integration points for AIP-69 in Internal API (#40901)
Browse files Browse the repository at this point in the history
* Have integration points for AIP-69 in Internal API

* Allow adding custom classes into serialization mapping

* Have integration points for AIP-69 in Internal API

* Fix pytest, reverse validation check for scheme

* Have integration points for AIP-69 in Internal API

* Review feedback
  • Loading branch information
jscheffl authored Jul 21, 2024
1 parent 3a00909 commit 6745cb8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 11 deletions.
4 changes: 2 additions & 2 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


@functools.lru_cache
def _initialize_map() -> dict[str, Callable]:
def initialize_method_map() -> dict[str, Callable]:
from airflow.cli.commands.task_command import _get_ti_db_access
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
Expand Down Expand Up @@ -147,7 +147,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
if json_rpc != "2.0":
return log_and_build_error_response(message="Expected jsonrpc 2.0 request.", status=400)

methods_map = _initialize_map()
methods_map = initialize_method_map()
method_name = body.get("method")
if method_name not in methods_map:
return log_and_build_error_response(message=f"Unrecognized method: {method_name}.", status=400)
Expand Down
25 changes: 21 additions & 4 deletions airflow/api_internal/internal_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
from functools import wraps
from typing import Callable, TypeVar
from urllib.parse import urlparse

import requests
import tenacity
Expand Down Expand Up @@ -58,6 +59,18 @@ def force_database_direct_access(message: str):
if _ENABLE_AIP_44:
logger.info("Forcing database direct access. %s", message)

@staticmethod
def force_api_access(api_endpoint: str):
    """
    Force the use of the Internal API at the given endpoint.

    After this call, every method decorated with ``internal_api_call`` is
    executed remotely via the API. This mode is needed for remote setups
    and remote executors.

    :param api_endpoint: endpoint URL stored as the Internal API endpoint;
        it is stored as given, without any validation here.
    """
    config = InternalApiConfig
    # Flag the configuration as initialized before storing the forced values.
    config._initialized = True
    config._use_internal_api = True
    config._internal_api_endpoint = api_endpoint

@staticmethod
def get_use_internal_api():
if not InternalApiConfig._initialized:
Expand All @@ -77,10 +90,14 @@ def _init_values():
raise RuntimeError("The AIP_44 is not enabled so you cannot use it.")
internal_api_endpoint = ""
if use_internal_api:
internal_api_url = conf.get("core", "internal_api_url")
internal_api_endpoint = internal_api_url + "/internal_api/v1/rpcapi"
if not internal_api_endpoint.startswith("http://"):
raise AirflowConfigException("[core]internal_api_url must start with http://")
url_conf = urlparse(conf.get("core", "internal_api_url"))
api_path = url_conf.path
if api_path in ["", "/"]:
# Add the default path if not given in the configuration
api_path = "/internal_api/v1/rpcapi"
if url_conf.scheme not in ["http", "https"]:
raise AirflowConfigException("[core]internal_api_url must start with http:// or https://")
internal_api_endpoint = f"{url_conf.scheme}://{url_conf.netloc}{api_path}"

InternalApiConfig._initialized = True
InternalApiConfig._use_internal_api = use_internal_api
Expand Down
9 changes: 8 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def deref(self, dag: DAG) -> ExpandInput:
LogTemplate: LogTemplatePydantic,
Dataset: DatasetPydantic,
}
_type_to_class: dict[DAT, list] = {
_type_to_class: dict[DAT | str, list] = {
DAT.BASE_JOB: [JobPydantic, Job],
DAT.TASK_INSTANCE: [TaskInstancePydantic, TaskInstance],
DAT.DAG_RUN: [DagRunPydantic, DagRun],
Expand All @@ -433,6 +433,13 @@ def deref(self, dag: DAG) -> ExpandInput:
_class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for cls_ in classes}


def add_pydantic_class_type_mapping(attribute_type: str, orm_class, pydantic_class):
    """
    Register a custom ORM/Pydantic class pair in the serialization mappings.

    :param attribute_type: type key under which the class pair is registered.
    :param orm_class: ORM class to map to ``pydantic_class``.
    :param pydantic_class: Pydantic class that corresponds to ``orm_class``.
    """
    _orm_to_model[orm_class] = pydantic_class
    registered_classes = [pydantic_class, orm_class]
    _type_to_class[attribute_type] = registered_classes
    # Reverse lookup: both classes resolve to the same type key.
    for cls_ in registered_classes:
        _class_to_type[cls_] = attribute_type


class BaseSerialization:
"""BaseSerialization provides utils for serialization."""

Expand Down
8 changes: 4 additions & 4 deletions tests/api_internal/endpoints/test_rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
mock_test_method.reset_mock()
mock_test_method.side_effect = None
with mock.patch(
"airflow.api_internal.endpoints.rpc_api_endpoint._initialize_map"
) as mock_initialize_map:
mock_initialize_map.return_value = {
"airflow.api_internal.endpoints.rpc_api_endpoint.initialize_method_map"
) as mock_initialize_method_map:
mock_initialize_method_map.return_value = {
TEST_METHOD_NAME: mock_test_method,
}
yield mock_initialize_map
yield mock_initialize_method_map

@pytest.mark.parametrize(
"input_params, method_result, result_cmp_func, method_params",
Expand Down

0 comments on commit 6745cb8

Please sign in to comment.