Add support for PEP 692: **kwargs: typing.Unpack[TypedDict] #579

Open
a-gardner1 opened this issue Sep 18, 2024 · 2 comments
Labels
enhancement New feature or request

Comments

@a-gardner1
Contributor

a-gardner1 commented Sep 18, 2024

🚀 Feature request

Allow using TypedDict for more precise **kwargs typing as described in PEP 692.
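For reference, here is a minimal, self-contained sketch of what such support has to recover at runtime. Given a callable whose **kwargs are annotated with Unpack[TypedDict], it lists each key with its type and whether the key is required. It uses only the standard inspect and typing modules (Python 3.11+, with an assumed typing_extensions fallback for older versions). The helper expand_unpacked_kwargs is hypothetical and is not part of jsonargparse.

```python
import inspect

try:  # Python 3.11+ ships all of these in typing
    from typing import (NotRequired, Required, TypedDict, Unpack,
                        get_args, get_origin, get_type_hints)
except ImportError:  # assumption: older Pythons use the typing_extensions backport
    from typing_extensions import (NotRequired, Required, TypedDict, Unpack,
                                   get_args, get_origin, get_type_hints)


class TestDict(TypedDict):
    a: Required[int]
    b: NotRequired[int]


class InnerTestClass:
    def __init__(self, **kwargs: Unpack[TestDict]) -> None:
        self.a = kwargs["a"]
        self.b = kwargs.get("b")


def expand_unpacked_kwargs(func):
    """Map each TypedDict key behind ``**kwargs: Unpack[...]`` to (type, required).

    Returns None when ``func`` has no ``**kwargs`` or when they are not
    annotated with ``Unpack[TypedDict]`` (hypothetical helper).
    """
    for param in inspect.signature(func).parameters.values():
        if param.kind is not inspect.Parameter.VAR_KEYWORD:
            continue
        if get_origin(param.annotation) is not Unpack:
            return None  # plain or untyped **kwargs: nothing to expand
        (typed_dict,) = get_args(param.annotation)
        # get_type_hints strips the Required/NotRequired wrappers,
        # while __required_keys__ still records which keys are mandatory.
        hints = get_type_hints(typed_dict)
        return {
            name: (hint, name in typed_dict.__required_keys__)
            for name, hint in hints.items()
        }
    return None  # no **kwargs at all


print(expand_unpacked_kwargs(InnerTestClass))
# → {'a': (<class 'int'>, True), 'b': (<class 'int'>, False)}
```

Combining the stripped type hints with __required_keys__ gives a parser both the type of each key and whether it may be omitted, without any source-code introspection.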

Motivation

I want to be able to enjoy static typing guarantees through mypy for classes or functions with TypedDict-annotated **kwargs, and to use those classes in configurations parsed by jsonargparse.

Right now, I have to either remove the annotation from **kwargs and hope that jsonargparse can inspect the code and infer the types using its heuristics, or expand the kwargs into explicit keyword parameters, duplicating keys that are already declared in the TypedDict.

Pitch

I want the following script to work without errors:

import tempfile
from dataclasses import dataclass
from typing import NotRequired, Required, TypedDict, Unpack

import yaml
from jsonargparse import ActionConfigFile, ArgumentParser, lazy_instance

if __name__ == '__main__':
    class TestDict(TypedDict):
        a: Required[int]
        """
        Test documentation.
        """
        b: NotRequired[int]

    class InnerTestClass:

        def __init__(self, **kwargs: Unpack[TestDict]) -> None:
            self.a = kwargs['a']
            self.b = kwargs.get('b')

    @dataclass
    class TestClass:

        test: InnerTestClass

    parser = ArgumentParser(exit_on_error=False)
    parser.add_argument(
        "-c",
        "--config",
        action=ActionConfigFile,
        help="Path to a configuration file in json or yaml format.")
    parser.add_class_arguments(
        TestClass,
        "test",
        fail_untyped=False,
        instantiate=True,
        sub_configs=True,
        default=lazy_instance(
            TestClass,
        ),
    )

    config = yaml.safe_dump(
        {
            "test":
                {
                    "test":
                        {
                            "class_path": f"{__name__}.InnerTestClass",
                            "init_args": {
                                "a": 2,
                            }
                        }
                }
        })

    with tempfile.NamedTemporaryFile("w", suffix=".yaml") as f:
        f.write(config)
        f.flush()
        cfg = parser.parse_args(["--config", f"{f.name}"])

    print(parser.dump(cfg, skip_link_targets=False, skip_none=False))

The script should print the following:

test:
  test:
    class_path: __main__.InnerTestClass
    init_args:
      a: 2

Partial Solution

The following diff against commit 7874273 provides a partial solution.
It does not account for Python versions older than 3.11, where Unpack is only available from typing_extensions (the diff imports typing._UnpackGenericAlias, which does not exist before 3.11), and it probably violates some conventions or expectations of the existing codebase.
It may also have unintended side effects.

Diff for Partial Solution
diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py
index 8c12b31..eda0717 100644
--- a/jsonargparse/_common.py
+++ b/jsonargparse/_common.py
@@ -16,6 +16,7 @@ from typing import (  # type: ignore[attr-defined]
     TypeVar,
     Union,
     _GenericAlias,
+    _UnpackGenericAlias,
 )
 
 from ._namespace import Namespace
@@ -102,6 +103,10 @@ def is_generic_class(cls) -> bool:
     return isinstance(cls, _GenericAlias) and getattr(cls, "__module__", "") != "typing"
 
 
+def is_unpack_typehint(cls) -> bool:
+    return isinstance(cls, _UnpackGenericAlias)
+
+
 def get_generic_origin(cls):
     return cls.__origin__ if is_generic_class(cls) else cls
 
diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py
index 9ec653b..a216c3d 100644
--- a/jsonargparse/_core.py
+++ b/jsonargparse/_core.py
@@ -1317,7 +1317,10 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
                 keys.append(action_dest)
             elif getattr(action, "jsonnet_ext_vars", False):
                 prev_cfg[action_dest] = value
-            cfg[action_dest] = value
+            if value == inspect._empty:
+                cfg.pop(action_dest, None)
+            else:
+                cfg[action_dest] = value
         return cfg[parent_key] if parent_key else cfg
 
     def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
@@ -1335,6 +1338,7 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
         with parser_context(parent_parser=self):
             ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from)
         ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to)
+        ActionTypeHint.delete_not_required_args(cfg_from, cfg_to)
         cfg_to.update(cfg_from)
         ActionTypeHint.apply_appends(self, cfg_to)
         return cfg_to
diff --git a/jsonargparse/_parameter_resolvers.py b/jsonargparse/_parameter_resolvers.py
index 8279fc7..2df199b 100644
--- a/jsonargparse/_parameter_resolvers.py
+++ b/jsonargparse/_parameter_resolvers.py
@@ -20,6 +20,7 @@ from ._common import (
     is_dataclass_like,
     is_generic_class,
     is_subclass,
+    is_unpack_typehint,
     parse_logger,
 )
 from ._optionals import get_annotated_base_type, is_annotated, is_pydantic_model, parse_docs
@@ -28,6 +29,7 @@ from ._stubs_resolver import get_stub_types
 from ._util import (
     ClassFromFunctionBase,
     get_import_path,
+    get_typehint_args,
     get_typehint_origin,
     iter_to_set_str,
     unique,
@@ -328,6 +330,38 @@ def replace_generic_type_vars(params: ParamList, parent) -> None:
             param.annotation = replace_type_vars(param.annotation)
 
 
+def unpack_typed_dict_kwargs(params: ParamList) -> bool:
+    kwargs_idx = get_arg_kind_index(params, kinds.VAR_KEYWORD)
+    if kwargs_idx >= 0:
+        kwargs = params[kwargs_idx]  # only remove it once it is known to be Unpack[...]
+        annotation = kwargs.annotation
+        if is_unpack_typehint(annotation):
+            annotation_args = get_typehint_args(annotation)
+            assert len(annotation_args) == 1, "Unpack requires a single type argument"
+            dict_annotations = annotation_args[0].__annotations__
+            new_params = []
+            for nm, annot in dict_annotations.items():
+                new_params.append(ParamData(
+                    name=nm,
+                    annotation=annot,
+                    default=inspect._empty,
+                    kind=inspect._ParameterKind.KEYWORD_ONLY,
+                    doc=None,
+                    component=kwargs.component,
+                    parent=kwargs.parent,
+                    origin=kwargs.origin
+                ))
+            # replace the **kwargs entry in-place with the TypedDict keys
+            params.pop(kwargs_idx)
+            trailing_params = []  # expected to be empty
+            while len(params) > kwargs_idx:
+                trailing_params.append(params.pop(kwargs_idx))
+            params.extend(new_params + trailing_params)
+            return True
+    return False
+
+
+
 def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component) -> None:
     if not stubs:
         return
@@ -848,12 +882,16 @@ class ParametersVisitor(LoggerProperty, ast.NodeVisitor):
             self.component, self.parent, self.logger
         )
         self.replace_param_default_subclass_specs(params)
+        if unpack_typed_dict_kwargs(params):
+            kwargs_idx = -1
         if args_idx >= 0 or kwargs_idx >= 0:
             self.doc_params = doc_params
             with mro_context(self.parent):
                 args, kwargs = self.get_parameters_args_and_kwargs()
             params = replace_args_and_kwargs(params, args, kwargs)
         add_stub_types(stubs, params, self.component)
+        # in case a typed-dict kwarg typehint is inherited
+        unpack_typed_dict_kwargs(params)
         params = self.remove_ignore_parameters(params)
         return params
 
@@ -865,6 +903,8 @@ def get_parameters_by_assumptions(
 ) -> ParamList:
     component, parent, method_name = get_component_and_parent(function_or_class, method_name)
     params, args_idx, kwargs_idx, _, stubs = get_signature_parameters_and_indexes(component, parent, logger)
+    if unpack_typed_dict_kwargs(params):
+        kwargs_idx = -1
 
     if parent and (args_idx >= 0 or kwargs_idx >= 0):
         with mro_context(parent):
@@ -875,6 +915,8 @@ def get_parameters_by_assumptions(
 
     params = replace_args_and_kwargs(params, [], [])
     add_stub_types(stubs, params, component)
+    # in case a typed-dict kwarg typehint is inherited
+    unpack_typed_dict_kwargs(params)
     return params
 
 
diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py
index 807a8d4..7d75e19 100644
--- a/jsonargparse/_signatures.py
+++ b/jsonargparse/_signatures.py
@@ -29,8 +29,9 @@ from ._typehints import (
     callable_instances,
     get_subclass_names,
     is_optional,
+    not_required_types,
 )
-from ._util import NoneType, get_private_kwargs, iter_to_set_str
+from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
 from .typing import register_pydantic_type
 
 __all__ = [
@@ -322,7 +323,7 @@ class SignatureArguments(LoggerProperty):
             default = param.default
             if default == inspect_empty and is_optional(annotation):
                 default = None
-        is_required = default == inspect_empty
+        is_required = default == inspect_empty and get_typehint_origin(annotation) not in not_required_types
         src = get_parameter_origins(param.component, param.parent)
         skip_message = f'Skipping parameter "{name}" from "{src}" because of: '
         if not fail_untyped and annotation == inspect_empty:
diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py
index 50a119b..e67c14f 100644
--- a/jsonargparse/_typehints.py
+++ b/jsonargparse/_typehints.py
@@ -439,6 +439,13 @@ class ActionTypeHint(Action):
                         if skip_key in parser.required_args:
                             del val.init_args[skip_key]
 
+    @staticmethod
+    def delete_not_required_args(cfg_from, cfg_to):
+        for key, val in list(cfg_to.items(branches=True)):
+            if val == inspect._empty and key not in cfg_from:
+                del cfg_to[key]
+
+
     @staticmethod
     @contextmanager
     def subclass_arg_context(parser):
@@ -587,6 +594,8 @@ class ActionTypeHint(Action):
                     assert ex  # needed due to ruff bug that removes " as ex"
                     if orig_val == "-" and isinstance(getattr(ex, "parent", None), PathError):
                         raise ex
+                    if get_typehint_origin(self._typehint) in not_required_types and val == inspect._empty:
+                        ex = None
                     try:
                         if isinstance(orig_val, str):
                             with change_to_path_dir(config_path):
@@ -943,6 +952,7 @@ def adapt_typehints(
     # TypedDict NotRequired and Required
     elif typehint_origin in not_required_required_types:
         assert len(subtypehints) == 1, "(Not)Required requires a single type argument"
         val = adapt_typehints(val, subtypehints[0], **adapt_kwargs)
 
     # Callable
diff --git a/jsonargparse/_util.py b/jsonargparse/_util.py
index e97ea2a..3c3d6c7 100644
--- a/jsonargparse/_util.py
+++ b/jsonargparse/_util.py
@@ -268,6 +268,10 @@ def object_path_serializer(value):
         raise ValueError(f"Only possible to serialize an importable object, given {value}: {ex}") from ex
 
 
+def get_typehint_args(typehint):
+    return getattr(typehint, "__args__", tuple())
+
+
 def get_typehint_origin(typehint):
     if not hasattr(typehint, "__origin__"):
         typehint_class = get_import_path(typehint.__class__)
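The central step of unpack_typed_dict_kwargs in the diff, replacing the **kwargs entry with one keyword-only parameter per TypedDict key, can also be expressed against the public inspect.Signature API. The sketch below is illustrative only: expand_signature and make are hypothetical names, it does not use jsonargparse's ParamData, and, like the diff, it keeps the Required/NotRequired wrappers in each annotation so that optional keys can still be recognized downstream.

```python
import inspect

try:  # Python 3.11+
    from typing import NotRequired, Required, TypedDict, Unpack, get_args, get_origin
except ImportError:  # assumed fallback to the typing_extensions backport
    from typing_extensions import NotRequired, Required, TypedDict, Unpack, get_args, get_origin


class TestDict(TypedDict):
    a: Required[int]
    b: NotRequired[int]


def make(x: int, **kwargs: Unpack[TestDict]) -> None:
    pass


def expand_signature(func) -> inspect.Signature:
    """Replace ``**kwargs: Unpack[TD]`` with one keyword-only parameter per key."""
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    for idx, param in enumerate(params):
        if param.kind is inspect.Parameter.VAR_KEYWORD and get_origin(param.annotation) is Unpack:
            (typed_dict,) = get_args(param.annotation)
            # A TypedDict's __annotations__ keeps the Required/NotRequired
            # wrappers, so optional keys remain distinguishable.
            new_params = [
                inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, annotation=annot)
                for name, annot in typed_dict.__annotations__.items()
            ]
            # Slice assignment splices the new parameters in place of **kwargs
            params[idx:idx + 1] = new_params
            break
    return sig.replace(parameters=params)
```

Signature.replace re-validates the parameter list, so a TypedDict key that clashes with an existing parameter name raises ValueError instead of being silently dropped.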
a-gardner1 added the enhancement (New feature or request) label on Sep 18, 2024
@mauvilsa
Member

I either have to remove the annotation from **kwargs

Why do you have to remove the annotation? Does keeping it currently make something fail?

Regarding support for **kwargs: typing.Unpack[TypedDict], it sounds good. If someone makes the effort to write a TypedDict, it can be assumed that by default it should be trusted, and code introspection should be skipped.

I haven't looked at the diff in detail, and I will probably propose changes. But please go ahead and create a pull request.

@a-gardner1
Contributor Author

Ah, yes. I should have included the error from running the given script. It currently fails with an error claiming that a is not a valid argument.

Unfortunately, I won't be able to get to a computer where I can work on this for several days, but I will open a PR as soon as I am able.
