Skip to content

Commit

Permalink
Add typing hints
Browse files Browse the repository at this point in the history
  • Loading branch information
xinghao728 committed Aug 12, 2024
1 parent 65bb1da commit 17868ac
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ release:

lint:
flake8 --exclude src/objprint/executing src/ tests/ --count --max-line-length=127 --ignore=W503
mypy src/ --exclude src/objprint/executing --follow-imports=skip
mypy src/ --exclude src/objprint/executing --follow-imports=skip --strict

test:
python -m unittest
Expand Down
14 changes: 7 additions & 7 deletions src/objprint/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,28 @@


import functools
from typing import Callable, Optional, Type, Set, Union
from typing import Any, Callable, Optional, Set, Union


def add_objprint(
orig_class: Optional[Type] = None,
format: str = "string", **kwargs) -> Union[Type, Callable[[Type], Type]]:
orig_class: Optional[type] = None,
format: str = "string", **kwargs: Any) -> Union[type, Callable[[type], type]]:

from . import _objprint

if format == "json":
import json

def __str__(self) -> str:
def __str__(self: Any) -> str:
return json.dumps(_objprint.objjson(self), **kwargs)
else:
def __str__(self) -> str:
def __str__(self: Any) -> str:
cfg = _objprint._configs.overwrite(**kwargs)
memo: Optional[Set] = set() if cfg.skip_recursion else None
memo: Optional[Set[Any]] = set() if cfg.skip_recursion else None
return _objprint._get_custom_object_str(self, memo, indent_level=0, cfg=cfg)

if orig_class is None:
def wrapper(cls: Type) -> Type:
def wrapper(cls: type) -> type:
cls.__str__ = functools.wraps(cls.__str__)(__str__) # type: ignore
return cls
return wrapper
Expand Down
2 changes: 1 addition & 1 deletion src/objprint/frame_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tokenize
from types import FrameType
from typing import List, Optional
from .executing.executing import Source # type: ignore
from .executing.executing import Source


class FrameAnalyzer:
Expand Down
32 changes: 17 additions & 15 deletions src/objprint/objprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import re
from types import FunctionType, FrameType
from typing import Any, Callable, Iterable, List, Optional, Set, TypeVar, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, TypeVar, Type, Tuple

from .color_util import COLOR, set_color
from .frame_analyzer import FrameAnalyzer
Expand All @@ -34,7 +34,7 @@ class _PrintConfig:
skip_recursion: bool = True
honor_existing: bool = True

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
for key, val in kwargs.items():
if hasattr(self, key):
if isinstance(val, type(getattr(self, key))):
Expand All @@ -44,7 +44,7 @@ def __init__(self, **kwargs):
else:
raise ValueError(f"{key} is not configurable")

def set(self, **kwargs) -> None:
def set(self, **kwargs: Any) -> None:
for key, val in kwargs.items():
if hasattr(_PrintConfig, key):
if isinstance(val, type(getattr(_PrintConfig, key))):
Expand All @@ -54,15 +54,15 @@ def set(self, **kwargs) -> None:
else:
raise ValueError(f"{key} is not configurable")

def overwrite(self, **kwargs) -> "_PrintConfig":
def overwrite(self, **kwargs: Any) -> "_PrintConfig":
ret = _PrintConfig(**kwargs)
return ret


class ObjPrint:
FormatterInfo = namedtuple('FormatterInfo', ['formatter', 'inherit'])

def __init__(self):
def __init__(self) -> None:
self._configs = _PrintConfig()

self.indicator_map = {
Expand All @@ -73,9 +73,9 @@ def __init__(self):
}
self._sys_print = print
self.frame_analyzer = FrameAnalyzer()
self.type_formatter = {}
self.type_formatter: Dict[type, ObjPrint.FormatterInfo] = {}

def __call__(self, *objs: Any, file: Any = None, format: str = "string", **kwargs) -> Any:
def __call__(self, *objs: Any, file: Any = None, format: str = "string", **kwargs: Any) -> Any:
cfg = self._configs.overwrite(**kwargs)
if cfg.enable:
# if inspect.currentframe() returns None, set call_frame to None
Expand All @@ -102,6 +102,7 @@ def __call__(self, *objs: Any, file: Any = None, format: str = "string", **kwarg

if format == "json":
if cfg.arg_name:
assert args is not None
for arg, obj in zip(args, objs):
self._sys_print(arg)
self._sys_print(json.dumps(self.objjson(obj), **kwargs))
Expand All @@ -112,6 +113,7 @@ def __call__(self, *objs: Any, file: Any = None, format: str = "string", **kwarg
# Force color with cfg as if color is not in cfg, objstr will default to False
kwargs["color"] = cfg.color
if cfg.arg_name:
assert args is not None
for arg, obj in zip(args, objs):
self._sys_print(arg)
self._sys_print(self.objstr(obj, **kwargs), file=file)
Expand All @@ -125,7 +127,7 @@ def __call__(self, *objs: Any, file: Any = None, format: str = "string", **kwarg

return objs[0] if len(objs) == 1 else objs

def objstr(self, obj: Any, **kwargs) -> str:
def objstr(self, obj: Any, **kwargs: Any) -> str:
# If no color option is specified, don't use color
if "color" not in kwargs:
kwargs["color"] = False
Expand All @@ -141,7 +143,7 @@ def _objstr(self, obj: Any, memo: Optional[Set[int]], indent_level: int, cfg: _P
if cls in self.type_formatter and (
cls == obj_type or self.type_formatter[cls].inherit
):
return self.type_formatter[cls].formatter(obj)
return self.type_formatter[cls].formatter(obj) # type: ignore

# If it's builtin type, return it directly
if isinstance(obj, str):
Expand Down Expand Up @@ -217,7 +219,7 @@ def _objjson(self, obj: Any, memo: Set[int]) -> Any:

return ret

def _get_custom_object_str(self, obj: Any, memo: Optional[Set[int]], indent_level: int, cfg: _PrintConfig):
def _get_custom_object_str(self, obj: Any, memo: Optional[Set[int]], indent_level: int, cfg: _PrintConfig) -> str:

def _get_method_line(attr: str) -> str:
if cfg.color:
Expand Down Expand Up @@ -264,7 +266,7 @@ def _get_line(key: str) -> str:

return self._get_pack_str(elems, obj, indent_level, cfg)

def _get_line_number_str(self, curr_frame: Optional[FrameType], cfg: _PrintConfig):
def _get_line_number_str(self, curr_frame: Optional[FrameType], cfg: _PrintConfig) -> str:
if curr_frame is None:
return "Unknown Line Number"
curr_code = curr_frame.f_code
Expand All @@ -279,7 +281,7 @@ def enable(self) -> None:
def disable(self) -> None:
self.config(enable=False)

def config(self, **kwargs) -> None:
def config(self, **kwargs: Any) -> None:
self._configs.set(**kwargs)

def install(self, name: str = "op") -> None:
Expand Down Expand Up @@ -325,13 +327,13 @@ def unregister_formatter(self, *obj_types: Type[Any]) -> None:
if obj_type in self.type_formatter:
del self.type_formatter[obj_type]

def get_formatter(self) -> dict:
def get_formatter(self) -> Dict[type, "ObjPrint.FormatterInfo"]:
return self.type_formatter

def _get_header_footer(self, obj: Any, cfg: _PrintConfig):
def _get_header_footer(self, obj: Any, cfg: _PrintConfig) -> Tuple[str, str]:
obj_type = type(obj)
if obj_type in self.indicator_map:
indicator = self.indicator_map[obj_type]
indicator = self.indicator_map[obj_type] # type: ignore
return indicator[0], indicator[1]
else:
if cfg.color:
Expand Down

0 comments on commit 17868ac

Please sign in to comment.