Skip to content

Commit

Permalink
fix mypy errors (#34)
Browse files Browse the repository at this point in the history
* fix mypy errors

* fix UnboundLocalError
  • Loading branch information
Ichunjo authored Aug 26, 2024
1 parent da3d642 commit eac7767
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 49 deletions.
34 changes: 17 additions & 17 deletions vskernels/kernels/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ def _base_ensure_obj(
elif isinstance(value, cls) or isinstance(value, basecls):
new_scaler = value
else:
new_scaler = cls.from_param(value, func_except)() # type: ignore
new_scaler = cls.from_param(value, func_except)()

if new_scaler.__class__ in excluded:
raise exception_cls(
func_except or cls.ensure_obj, new_scaler.__class__, # type: ignore
func_except or cls.ensure_obj, new_scaler.__class__,
'This {cls_name} can\'t be instantiated to be used!',
cls_name=new_scaler.__class__
)
Expand Down Expand Up @@ -215,7 +215,7 @@ class Scaler(BaseScaler):

@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore[override]
def scale(
self, clip: vs.VideoNode, width: int | None = None, height: int | None = None,
shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any
) -> vs.VideoNode:
Expand Down Expand Up @@ -254,8 +254,8 @@ def get_scale_args(
| kwargs
)

def get_implemented_funcs(self) -> tuple[Callable[..., Any]]:
return (self.scale, self.multi) # type: ignore
def get_implemented_funcs(self) -> tuple[Callable[..., Any], ...]:
return (self.scale, self.multi)


class Descaler(BaseScaler):
Expand All @@ -270,7 +270,7 @@ class Descaler(BaseScaler):

@inject_self.cached
@inject_kwargs_params
def descale( # type: ignore[override]
def descale(
self, clip: vs.VideoNode, width: int | None, height: int | None,
shift: tuple[TopShift, LeftShift] | tuple[
TopShift | tuple[TopFieldTopShift, BotFieldTopShift],
Expand Down Expand Up @@ -324,7 +324,7 @@ def descale( # type: ignore[override]

kwargs, shift = sample_grid_model.for_src(clip, width, height, shift, **kwargs) # type: ignore

de_kwargs = self.get_descale_args(clip, shift, *de_base_args, **kwargs) # type: ignore
de_kwargs = self.get_descale_args(clip, shift, *de_base_args, **kwargs)

descaled = self.descale_function(clip, **_norm_props_enums(de_kwargs))

Expand All @@ -346,7 +346,7 @@ def get_descale_args(
| kwargs
)

def get_implemented_funcs(self) -> tuple[Callable[..., Any]]:
def get_implemented_funcs(self) -> tuple[Callable[..., Any], ...]:
return (self.descale, )


Expand Down Expand Up @@ -385,24 +385,24 @@ def get_resample_args(
| kwargs
)

def get_implemented_funcs(self) -> tuple[Callable[..., Any]]:
def get_implemented_funcs(self) -> tuple[Callable[..., Any], ...]:
return (self.resample, )


class Kernel(Scaler, Descaler, Resampler): # type: ignore
class Kernel(Scaler, Descaler, Resampler):
"""
Abstract kernel interface.
"""

_err_class = UnknownKernelError # type: ignore

@overload # type: ignore
@overload
@inject_self.cached
@inject_kwargs_params
def shift(self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any) -> vs.VideoNode:
...

@overload # type: ignore
@overload
@inject_self.cached
@inject_kwargs_params
def shift(
Expand Down Expand Up @@ -475,14 +475,14 @@ def from_param(

@overload
@classmethod
def from_param( # type: ignore
def from_param(
cls: type[Kernel], kernel: ScalerT | KernelT | None = None, func_except: FuncExceptT | None = None
) -> type[Scaler]:
...

@overload
@classmethod
def from_param( # type: ignore
def from_param(
cls: type[Kernel], kernel: DescalerT | KernelT | None = None, func_except: FuncExceptT | None = None
) -> type[Descaler]:
...
Expand Down Expand Up @@ -513,14 +513,14 @@ def ensure_obj(

@overload
@classmethod
def ensure_obj( # type: ignore
def ensure_obj(
cls: type[Kernel], kernel: ScalerT | KernelT | None = None, func_except: FuncExceptT | None = None
) -> Scaler:
...

@overload
@classmethod
def ensure_obj( # type: ignore
def ensure_obj(
cls: type[Kernel], kernel: DescalerT | KernelT | None = None, func_except: FuncExceptT | None = None
) -> Descaler:
...
Expand Down Expand Up @@ -587,7 +587,7 @@ def get_resample_args(
| self.get_params_args(False, clip, **kwargs)
)

def get_implemented_funcs(self) -> tuple[Callable[..., Any]]:
def get_implemented_funcs(self) -> tuple[Callable[..., Any], ...]:
return (self.shift, ) # type: ignore


Expand Down
2 changes: 1 addition & 1 deletion vskernels/kernels/bicubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, b: float = 0, c: float = 1 / 2, **kwargs: Any) -> None:
super().__init__(**kwargs)

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
x, b, c = abs(x), self.b, self.c

if (x < 1.0):
Expand Down
4 changes: 2 additions & 2 deletions vskernels/kernels/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def _handle_crop_resize_kwargs(
src_sar = float(_from_param(Sar, sar, Sar(1, 1)) or Sar.from_clip(clip))
out_sar = None

out_dar = float(_from_param(Dar, dar, None) or Dar.from_size(width, height)) # type: ignore
src_dar = float(fallback(_from_param(Dar, dar_in, out_dar), Dar.from_size(clip, False)))
out_dar = float(_from_param(Dar, dar, Dar(0)) or Dar.from_size(width, height))
src_dar = float(fallback(_from_param(Dar, dar_in, Dar(out_dar)), Dar.from_size(clip, False)))

if src_sar != 1.0:
if src_sar > 1.0:
Expand Down
26 changes: 15 additions & 11 deletions vskernels/kernels/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from inspect import Signature

from vstools import vs, core
from typing import Any
from typing import Any, Protocol
from .abstract import Kernel

from typing import TypeVar
Expand All @@ -14,51 +14,55 @@
]


class _kernel_func(Protocol):
def __call__(self, *, x: float) -> float:
...


class CustomKernel(Kernel):
@inject_self.cached
def kernel(self: CustomKernelT, *, x: float) -> float:
def kernel(self, *, x: float) -> float:
raise NotImplementedError

def _modify_kernel_func(self, kwargs: KwargsT):
def _modify_kernel_func(self, kwargs: KwargsT) -> tuple[_kernel_func, float]:
blur = float(kwargs.pop('blur', 1.0))
taps = int(kwargs.pop('taps', self.kernel_radius))
support = taps * blur

if blur != 1.0:
def kernel(x: float) -> float:
def kernel(*, x: float) -> float:
return self.kernel(x=x / blur)

return kernel, support

return self.kernel, support

@inject_self
def scale_function(
def scale_function( # type: ignore[override]
self, clip: vs.VideoNode, width: int | None = None, height: int | None = None, *args: Any, **kwargs: Any
) -> vs.VideoNode:
custom_kernel_vars = self._modify_kernel_func(kwargs)
kernel, support = self._modify_kernel_func(kwargs)

clean_kwargs = {
k: v for k, v in kwargs.items()
if k not in Signature.from_callable(self._modify_kernel_func).parameters.keys()
}

return core.resize2.Custom(clip, *custom_kernel_vars, width, height, *args, **clean_kwargs)
return core.resize2.Custom(clip, kernel, int(support), width, height, *args, **clean_kwargs)

resample_function = scale_function

@inject_self
def descale_function(
def descale_function( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, *args: Any, **kwargs: Any
) -> vs.VideoNode:
custom_kernel_vars = self._modify_kernel_func(kwargs)
kernel, support = self._modify_kernel_func(kwargs)

clean_kwargs = {
k: v for k, v in kwargs.items()
if k not in Signature.from_callable(self._modify_kernel_func).parameters.keys()
}

return core.descale.Decustom(clip, width, height, *custom_kernel_vars, *args, **clean_kwargs)
return core.descale.Decustom(clip, width, height, kernel, int(support), *args, **clean_kwargs)

def get_params_args(
self, is_descale: bool, clip: vs.VideoNode, width: int | None = None, height: int | None = None, **kwargs: Any
Expand Down
4 changes: 2 additions & 2 deletions vskernels/kernels/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, taps: float = 2, **kwargs: Any) -> None:
self._coefs = self._splineKernelCoeff()

def _naturalCubicSpline(self, values: list[int]) -> list[float]:
import numpy as np # type: ignore
import numpy as np

n = len(values) - 1

Expand Down Expand Up @@ -85,7 +85,7 @@ def _shiftPolynomial(coeffs: list[float], shift: float) -> list[float]:
return coeffs

@inject_self.cached
def kernel(self, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
x, taps = abs(x), self.kernel_radius

if x >= taps:
Expand Down
24 changes: 12 additions & 12 deletions vskernels/kernels/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Point(CustomComplexKernel):
_static_kernel_radius = 1

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
return 1.0


Expand All @@ -63,7 +63,7 @@ class Bilinear(CustomComplexKernel):
_static_kernel_radius = 1

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
return max(1.0 - abs(x), 0.0)


Expand All @@ -78,7 +78,7 @@ def __init__(self, taps: int = 3, **kwargs: Any) -> None:
super().__init__(taps, **kwargs)

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
x, taps = abs(x), self.kernel_radius

return sinc(x) * sinc(x / taps) if x < taps else 0.0
Expand All @@ -99,7 +99,7 @@ def sigma(self) -> gauss_sigma:
return gauss_sigma(self._sigma)

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
return 1 / (self._sigma * sqrt(2 * pi)) * exp(-x ** 2 / (2 * self._sigma ** 2))


Expand All @@ -109,7 +109,7 @@ class Box(CustomComplexKernel):
_static_kernel_radius = 1

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
return 1.0 if x >= -0.5 and x < 0.5 else 0.0


Expand All @@ -125,7 +125,7 @@ def _win_coef(self, x: float) -> float:
return 0.42 + 0.50 * cos(w_x) + 0.08 * cos(w_x * 2)

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
if x >= self.kernel_radius:
return 0.0

Expand All @@ -148,7 +148,7 @@ def __init__(self, taps: int = 4, **kwargs: Any) -> None:
super().__init__(taps, **kwargs)

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
if x >= self.kernel_radius:
return 0.0

Expand All @@ -159,7 +159,7 @@ class Hann(CustomComplexTapsKernel):
"""Hann kernel."""

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
if x >= self.kernel_radius:
return 0.0

Expand All @@ -170,7 +170,7 @@ class Hamming(CustomComplexTapsKernel):
"""Hamming kernel."""

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
if x >= self.kernel_radius:
return 0.0

Expand All @@ -181,7 +181,7 @@ class Welch(CustomComplexTapsKernel):
"""Welch kernel."""

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
if abs(x) >= 1.0:
return 0.0

Expand All @@ -192,7 +192,7 @@ class Cosine(CustomComplexTapsKernel):
"""Cosine kernel."""

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
if x >= self.kernel_radius:
return 0.0

Expand All @@ -205,7 +205,7 @@ class Bohman(CustomComplexTapsKernel):
"""Bohman kernel."""

@inject_self.cached
def kernel(self, *, x: float) -> float: # type: ignore
def kernel(self, *, x: float) -> float:
if x >= self.kernel_radius:
return 0.0

Expand Down
5 changes: 3 additions & 2 deletions vskernels/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,17 @@ class SampleGridModel(CustomIntEnum):
MATCH_CENTERS = 1

def __call__(
self, width: int, height: int, src_width: int, src_height: int, shift: tuple[float, float], kwargs: KwargsT
self, width: int, height: int, src_width: float, src_height: float, shift: tuple[float, float], kwargs: KwargsT
) -> tuple[KwargsT, tuple[float, float]]:
if self is SampleGridModel.MATCH_CENTERS:
src_width = src_width * (width - 1) / (src_width - 1)
src_height = src_height * (height - 1) / (src_height - 1)

kwargs |= dict(src_width=src_width, src_height=src_height)
shift = tuple[float, float](
shift_x, shift_y, *_ = tuple(
(x / 2 + y for x, y in zip(((height - src_height), (width - src_width)), shift))
)
shift = shift_x, shift_y

return kwargs, shift

Expand Down
4 changes: 2 additions & 2 deletions vskernels/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def scale( # type: ignore
) -> vs.VideoNode:
try:
width, height = Scaler._wh_norm(clip, width, height)
return super().scale(clip, clip.width, clip.height, shift, **kwargs) # type: ignore
return super().scale(clip, clip.width, clip.height, shift, **kwargs)
except Exception:
return clip

Expand Down Expand Up @@ -205,7 +205,7 @@ def out(self) -> vs.VideoNode:
def __enter__(self) -> LinearLightProcessing:
self.linear = self.linear or not not self.sigmoid

if self.sigmoid:
if self.sigmoid is not False:
if self.sigmoid is True:
self.sigmoid = (6.5, 0.75)

Expand Down

0 comments on commit eac7767

Please sign in to comment.