Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have integration points for AIP-69 in Internal API #40901

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


@functools.lru_cache
def _initialize_map() -> dict[str, Callable]:
def initialize_method_map() -> dict[str, Callable]:
from airflow.cli.commands.task_command import _get_ti_db_access
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
Expand Down Expand Up @@ -148,7 +148,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
if json_rpc != "2.0":
return log_and_build_error_response(message="Expected jsonrpc 2.0 request.", status=400)

methods_map = _initialize_map()
methods_map = initialize_method_map()
method_name = body.get("method")
if method_name not in methods_map:
return log_and_build_error_response(message=f"Unrecognized method: {method_name}.", status=400)
Expand Down
24 changes: 20 additions & 4 deletions airflow/api_internal/internal_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
from functools import wraps
from typing import Callable, TypeVar
from urllib.parse import urlparse

import requests
import tenacity
Expand Down Expand Up @@ -56,6 +57,18 @@ def force_database_direct_access():
InternalApiConfig._initialized = True
InternalApiConfig._use_internal_api = False

@staticmethod
def force_api_access(api_endpoint: str):
"""
Force using Internal API with provided endpoint.

All methods decorated with internal_api_call will always be executed remote/via API.
This mode is needed for remote setups/remote executor.
"""
InternalApiConfig._initialized = True
InternalApiConfig._use_internal_api = True
InternalApiConfig._internal_api_endpoint = api_endpoint

@staticmethod
def get_use_internal_api():
if not InternalApiConfig._initialized:
Expand All @@ -75,10 +88,13 @@ def _init_values():
raise RuntimeError("The AIP_44 is not enabled so you cannot use it.")
internal_api_endpoint = ""
if use_internal_api:
internal_api_url = conf.get("core", "internal_api_url")
internal_api_endpoint = internal_api_url + "/internal_api/v1/rpcapi"
if not internal_api_endpoint.startswith("http://"):
raise AirflowConfigException("[core]internal_api_url must start with http://")
url_conf = urlparse(conf.get("core", "internal_api_url"))
api_path = url_conf.path
if len(api_path) < 2:
api_path = "/internal_api/v1/rpcapi"
potiuk marked this conversation as resolved.
Show resolved Hide resolved
if url_conf.scheme in ["http", "https"]:
raise AirflowConfigException("[core]internal_api_url must start with http:// or https://")
InternalApiConfig._internal_api_endpoint = f"{url_conf.scheme}://{url_conf.netloc}{api_path}"

InternalApiConfig._initialized = True
InternalApiConfig._use_internal_api = use_internal_api
Expand Down
9 changes: 8 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def deref(self, dag: DAG) -> ExpandInput:
LogTemplate: LogTemplatePydantic,
Dataset: DatasetPydantic,
}
_type_to_class: dict[DAT, list] = {
_type_to_class: dict[DAT | str, list] = {
DAT.BASE_JOB: [JobPydantic, Job],
DAT.TASK_INSTANCE: [TaskInstancePydantic, TaskInstance],
DAT.DAG_RUN: [DagRunPydantic, DagRun],
Expand All @@ -433,6 +433,13 @@ def deref(self, dag: DAG) -> ExpandInput:
_class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for cls_ in classes}


def add_pydantic_class_type_mapping(attribute_type: str, orm_class, pydantic_class):
_orm_to_model[orm_class] = pydantic_class
_type_to_class[attribute_type] = [pydantic_class, orm_class]
_class_to_type[pydantic_class] = attribute_type
_class_to_type[orm_class] = attribute_type


class BaseSerialization:
"""BaseSerialization provides utils for serialization."""

Expand Down
8 changes: 4 additions & 4 deletions tests/api_internal/endpoints/test_rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
mock_test_method.reset_mock()
mock_test_method.side_effect = None
with mock.patch(
"airflow.api_internal.endpoints.rpc_api_endpoint._initialize_map"
) as mock_initialize_map:
mock_initialize_map.return_value = {
"airflow.api_internal.endpoints.rpc_api_endpoint.initialize_method_map"
) as mock_initialize_method_map:
mock_initialize_method_map.return_value = {
TEST_METHOD_NAME: mock_test_method,
}
yield mock_initialize_map
yield mock_initialize_method_map

@pytest.mark.parametrize(
"input_params, method_result, result_cmp_func, method_params",
Expand Down
Loading