Skip to content

Commit

Permalink
Add border_handling to ComplexKernel.scale
Browse files Browse the repository at this point in the history
  • Loading branch information
Setsugennoao committed Dec 14, 2023
1 parent 4a2a729 commit 87f93fd
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions vskernels/kernels/complex.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, SupportsFloat, TypeVar, Union, cast

from stgpytools import inject_kwargs_params
from vstools import (
CustomIntEnum, Dar, KwargsT, Resolution, Sar, VSFunctionAllArgs, check_correct_subsampling, inject_self, vs
CustomIntEnum, Dar, KwargsT, Resolution, Sar, VSFunctionAllArgs, check_correct_subsampling, inject_self, padder, vs
)

from ..types import Center, LeftShift, Slope, TopShift
Expand All @@ -29,6 +30,31 @@ class BorderHandling(CustomIntEnum):
ZERO = 1
REPEAT = 2

def prepare_clip(self, clip: vs.VideoNode, min_pad: int = 2) -> vs.VideoNode:
pad_w, pad_h = (
self.pad_amount(size, min_pad) for size in (clip.width, clip.height)
)

if pad_w == pad_h == 0:
return clip

args = (clip, pad_w, pad_w, pad_h, pad_h)

match self:
case BorderHandling.MIRROR:
return padder.MIRROR(*args)
case BorderHandling.ZERO:
return padder.COLOR(*args)
case BorderHandling.REPEAT:
return padder.REPEAT(*args)

@lru_cache
def pad_amount(self, size: int, min_amount: int = 2) -> int:
if self is BorderHandling.MIRROR:
return 0

return (((size + min_amount) + 7) & -8) - size


def _from_param(cls: type[XarT], value: XarT | bool | float | None, fallback: XarT) -> XarT | None:
if value is False:
Expand Down Expand Up @@ -72,7 +98,7 @@ def func(
resampler: Resampler | None = self if isinstance(self, Resampler) else None

with LinearLight(clip, linear, sigmoid, resampler, kwargs.pop('format', None)) as ll:
ll.linear = operation(ll.linear, width, height, shift, **kwargs)
ll.linear = operation(ll.linear, width, height, shift, **kwargs) # type: ignore

return ll.out

Expand Down Expand Up @@ -174,6 +200,7 @@ def _handle_crop_resize_kwargs( # type: ignore[override]
@inject_kwargs_params
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0), *,
border_handling: BorderHandling = BorderHandling.MIRROR,
sar: Sar | float | bool | None = None, dar: Dar | float | bool | None = None, keep_ar: bool = False,
**kwargs: Any
) -> vs.VideoNode:
Expand All @@ -186,6 +213,12 @@ def scale( # type: ignore[override]

kwargs, shift, out_sar = self._handle_crop_resize_kwargs(clip, width, height, shift, **kwargs)

padded = border_handling.prepare_clip(clip, self.kernel_radius)

shift = tuple( # type: ignore
s + p - c for s, p, c in zip(shift, *((x.width, x.height) for x in (clip, padded)))
)

kwargs = self.get_scale_args(clip, shift, width, height, **kwargs)

clip = self.scale_function(clip, **kwargs)
Expand All @@ -202,12 +235,15 @@ class ComplexScaler(LinearScaler, KeepArScaler):
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*,
border_handling: BorderHandling = BorderHandling.MIRROR,
sar: Sar | bool | float | None = None, dar: Dar | bool | float | None = None, keep_ar: bool = False,
linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False,
**kwargs: Any
) -> vs.VideoNode:
return super().scale(
clip, width, height, shift, sar=sar, dar=dar, keep_ar=keep_ar, linear=linear, sigmoid=sigmoid, **kwargs
clip, width, height, shift, sar=sar, dar=dar, keep_ar=keep_ar,
linear=linear, sigmoid=sigmoid, border_handling=border_handling,
**kwargs
)


Expand Down

0 comments on commit 87f93fd

Please sign in to comment.