Skip to content

Commit

Permalink
Add SampleGridModel
Browse files Browse the repository at this point in the history
  • Loading branch information
Setsugennoao committed Aug 4, 2024
1 parent 66c853f commit 4340ab5
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 54 deletions.
1 change: 1 addition & 0 deletions vskernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import exceptions, kernels, util # noqa: F401, F403
from .exceptions import * # noqa: F401, F403
from .kernels import * # noqa: F401, F403
from .types import * # noqa: F401, F403
from .util import * # noqa: F401, F403
25 changes: 19 additions & 6 deletions vskernels/kernels/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from vstools.enums.color import _norm_props_enums

from ..exceptions import UnknownDescalerError, UnknownKernelError, UnknownResamplerError, UnknownScalerError
from ..types import BotFieldLeftShift, BotFieldTopShift, LeftShift, TopFieldLeftShift, TopFieldTopShift, TopShift
from ..types import (
BorderHandling, BotFieldLeftShift, BotFieldTopShift, LeftShift, SampleGridModel, TopFieldLeftShift,
TopFieldTopShift, TopShift
)

__all__ = [
'Scaler', 'ScalerT',
Expand Down Expand Up @@ -134,8 +137,8 @@ def __init_subclass__(cls) -> None:
if not _finished_loading_abstract:
return

from .complex import CustomComplexKernel
from ..util import abstract_kernels
from .complex import CustomComplexKernel

if cls in abstract_kernels:
return
Expand Down Expand Up @@ -272,25 +275,33 @@ def descale( # type: ignore[override]
shift: tuple[TopShift, LeftShift] | tuple[
TopShift | tuple[TopFieldTopShift, BotFieldTopShift],
LeftShift | tuple[TopFieldLeftShift, BotFieldLeftShift]
] = (0, 0), **kwargs: Any
] = (0, 0), *,
border_handling: BorderHandling = BorderHandling.MIRROR,
sample_grid_model: SampleGridModel = SampleGridModel.MATCH_EDGES,
field_based: FieldBased | None = None,
**kwargs: Any
) -> vs.VideoNode:
width, height = self._wh_norm(clip, width, height)

check_correct_subsampling(clip, width, height)

field_based = FieldBased.from_param_or_video(kwargs.pop('field_based', None), clip)
field_based = FieldBased.from_param_or_video(field_based, clip)

clip, bits = expect_bits(clip, 32)

de_base_args = (width, height // (1 + field_based.is_inter))
kwargs |= dict(border_handling=border_handling)

if field_based.is_inter:
shift_y, shift_x = tuple[tuple[float, float], ...](
sh if isinstance(sh, tuple) else (sh, sh) for sh in shift
)

de_kwargs_tf = self.get_descale_args(clip, (shift_y[0], shift_x[0]), *de_base_args, **kwargs)
de_kwargs_bf = self.get_descale_args(clip, (shift_y[1], shift_x[1]), *de_base_args, **kwargs)
kwargs_tf, shift = sample_grid_model.for_descale(clip, width, height, (shift_y[0], shift_x[0]), **kwargs)
kwargs_bf, shift = sample_grid_model.for_descale(clip, width, height, (shift_y[1], shift_x[1]), **kwargs)

de_kwargs_tf = self.get_descale_args(clip, (shift_y[0], shift_x[0]), *de_base_args, **kwargs_tf)
de_kwargs_bf = self.get_descale_args(clip, (shift_y[1], shift_x[1]), *de_base_args, **kwargs_bf)

if height % 2:
raise CustomIndexError('You can\'t descale to odd resolution when crossconverted!', self.descale)
Expand All @@ -311,6 +322,8 @@ def descale( # type: ignore[override]
if any(isinstance(sh, tuple) for sh in shift):
raise CustomValueError('You can\'t descale per-field when the input is progressive!', self.descale)

kwargs, shift = sample_grid_model.for_descale(clip, width, height, shift, **kwargs) # type: ignore

de_kwargs = self.get_descale_args(clip, shift, *de_base_args, **kwargs) # type: ignore

descaled = self.descale_function(clip, **_norm_props_enums(de_kwargs))
Expand Down
54 changes: 13 additions & 41 deletions vskernels/kernels/complex.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, SupportsFloat, TypeVar, Union, cast
from math import ceil
from typing import TYPE_CHECKING, Any, SupportsFloat, TypeVar, Union, cast

from stgpytools import inject_kwargs_params
from vstools import (
CustomIntEnum, Dar, KwargsT, Resolution, Sar, VSFunctionAllArgs, check_correct_subsampling, fallback, inject_self,
padder, vs
Dar, KwargsT, Resolution, Sar, VSFunctionAllArgs, check_correct_subsampling, fallback, inject_self, vs
)

from ..types import Center, LeftShift, Slope, TopShift
from ..types import BorderHandling, Center, LeftShift, SampleGridModel, Slope, TopShift
from .abstract import Descaler, Kernel, Resampler, Scaler
from .custom import CustomKernel

__all__ = [
'BorderHandling',

'LinearScaler', 'LinearDescaler',

'KeepArScaler',
Expand All @@ -30,37 +27,6 @@
XarT = TypeVar('XarT', Sar, Dar)


class BorderHandling(CustomIntEnum):
MIRROR = 0
ZERO = 1
REPEAT = 2

def prepare_clip(self, clip: vs.VideoNode, min_pad: int = 2) -> vs.VideoNode:
pad_w, pad_h = (
self.pad_amount(size, min_pad) for size in (clip.width, clip.height)
)

if pad_w == pad_h == 0:
return clip

args = (clip, pad_w, pad_w, pad_h, pad_h)

match self:
case BorderHandling.MIRROR:
return padder.MIRROR(*args)
case BorderHandling.ZERO:
return padder.COLOR(*args)
case BorderHandling.REPEAT:
return padder.REPEAT(*args)

@lru_cache
def pad_amount(self, size: int, min_amount: int = 2) -> int:
if self is BorderHandling.MIRROR:
return 0

return (((size + min_amount) + 7) & -8) - size


def _from_param(cls: type[XarT], value: XarT | bool | float | None, fallback: XarT) -> XarT | None:
if value is False:
return fallback
Expand Down Expand Up @@ -214,6 +180,7 @@ def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int | None = None, height: int | None = None,
shift: tuple[TopShift, LeftShift] = (0, 0), *,
border_handling: BorderHandling = BorderHandling.MIRROR,
sample_grid_model: SampleGridModel = SampleGridModel.MATCH_EDGES,
sar: Sar | float | bool | None = None, dar: Dar | float | bool | None = None,
dar_in: Dar | bool | float | None = None, keep_ar: bool | None = None,
**kwargs: Any
Expand All @@ -229,6 +196,8 @@ def scale( # type: ignore[override]

kwargs, shift, out_sar = self._handle_crop_resize_kwargs(clip, width, height, shift, **kwargs)

kwargs, shift = sample_grid_model.for_scale(clip, width, height, shift, **kwargs)

padded = border_handling.prepare_clip(clip, self.kernel_radius)

shift, clip = tuple(
Expand All @@ -251,6 +220,7 @@ def scale( # type: ignore[override]
shift: tuple[TopShift, LeftShift] = (0, 0),
*,
border_handling: BorderHandling = BorderHandling.MIRROR,
sample_grid_model: SampleGridModel = SampleGridModel.MATCH_EDGES,
sar: Sar | bool | float | None = None, dar: Dar | bool | float | None = None, keep_ar: bool | None = None,
linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False,
**kwargs: Any
Expand All @@ -259,7 +229,7 @@ def scale( # type: ignore[override]
return super().scale(
clip, width, height, shift, sar=sar, dar=dar, keep_ar=keep_ar,
linear=linear, sigmoid=sigmoid, border_handling=border_handling,
**kwargs
sample_grid_model=sample_grid_model, **kwargs
)


Expand All @@ -273,8 +243,10 @@ class CustomComplexKernel(CustomKernel, ComplexKernel): # type: ignore
@inject_kwargs_params
def descale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*, blur: float = 1.0, border_handling: BorderHandling, ignore_mask: vs.VideoNode | None = None,
linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False, **kwargs: Any
*, blur: float = 1.0, border_handling: BorderHandling,
sample_grid_model: SampleGridModel = SampleGridModel.MATCH_EDGES,
ignore_mask: vs.VideoNode | None = None, linear: bool = False,
sigmoid: bool | tuple[Slope, Center] = False, **kwargs: Any
) -> vs.VideoNode:
...

Expand Down
81 changes: 74 additions & 7 deletions vskernels/types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,83 @@
from __future__ import annotations

from typing import TypeAlias
from functools import lru_cache
from typing import Any, TypeAlias

from vstools import CustomIntEnum, KwargsT, padder, vs

__all__ = [
'TopShift',
'LeftShift',
'TopFieldTopShift',
'TopFieldLeftShift',
'BotFieldTopShift',
'BotFieldLeftShift',
'BorderHandling', 'SampleGridModel'
]


class BorderHandling(CustomIntEnum):
MIRROR = 0
ZERO = 1
REPEAT = 2

def prepare_clip(self, clip: vs.VideoNode, min_pad: int = 2) -> vs.VideoNode:
pad_w, pad_h = (
self.pad_amount(size, min_pad) for size in (clip.width, clip.height)
)

if pad_w == pad_h == 0:
return clip

args = (clip, pad_w, pad_w, pad_h, pad_h)

match self:
case BorderHandling.MIRROR:
return padder.MIRROR(*args)
case BorderHandling.ZERO:
return padder.COLOR(*args)
case BorderHandling.REPEAT:
return padder.REPEAT(*args)

@lru_cache
def pad_amount(self, size: int, min_amount: int = 2) -> int:
if self is BorderHandling.MIRROR:
return 0

return (((size + min_amount) + 7) & -8) - size


class SampleGridModel(CustomIntEnum):
MATCH_EDGES = 0
MATCH_CENTERS = 1

def __call__(
self, width: int, height: int, src_width: int, src_height: int, shift: tuple[float, float]
) -> tuple[KwargsT, tuple[float, float]]:
kwargs = KwargsT()

if self is SampleGridModel.MATCH_CENTERS:
src_width = src_width * (width - 1) / (src_width - 1)
src_height = src_height * (height - 1) / (src_height - 1)

kwargs |= dict(src_width=src_width, src_height=src_height)
shift = tuple[float, float](
(x / 2 + y for x, y in zip(((height - src_height), (width - src_width)), shift))
)

return kwargs, shift

def for_scale(
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float], **kwargs: Any
) -> tuple[KwargsT, tuple[float, float]]:
src_width = kwargs.pop('src_width', width)
src_height = kwargs.pop('src_height', height)

return self(src_width, src_height, width, height, shift)

def for_descale(
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float], **kwargs: Any
) -> tuple[KwargsT, tuple[float, float]]:
src_width = kwargs.pop('src_width', clip.width)
src_height = kwargs.pop('src_height', clip.height)

return self(width, height, src_width, src_height, shift)


TopShift: TypeAlias = float
LeftShift: TypeAlias = float
TopFieldTopShift: TypeAlias = float
Expand Down

0 comments on commit 4340ab5

Please sign in to comment.