forked from kornia/kornia
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feat] Added dissolving transformation & updated docs (kornia#2961)
* updated * update * update * update * Added tests for LazyLoader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated typing * Added tests * Updated caching for diffusers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * updated * update * update * update * update * update * update * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * update * update * update * updated README * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * updated * update * update * Update README.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edgar Riba <edgar.riba@gmail.com>
- Loading branch information
1 parent
06bc50d
commit 74a9742
Showing
19 changed files
with
466 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,34 @@ | ||
import argparse | ||
import os | ||
|
||
import diffusers | ||
import torch | ||
|
||
fonts = { | ||
"sold2_wireframe": "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth", | ||
models = { | ||
"sold2_wireframe": ("torchhub", "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"), | ||
"stabilityai/stable-diffusion-2-1": ("diffusers", "StableDiffusionPipeline"), | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser("WeightsDownloader") | ||
parser.add_argument("--target_directory", "-t", required=False, default="target_directory") | ||
|
||
args = parser.parse_args() | ||
|
||
torch.hub.set_dir(args.target_directory) | ||
# For HuggingFace model caching | ||
os.environ["HF_HOME"] = args.target_directory | ||
|
||
for name, url in fonts.items(): | ||
print(f"Downloading weights of `{name}` from `url`. Caching to dir `{args.target_directory}`") | ||
torch.hub.load_state_dict_from_url(url, model_dir=args.target_directory, map_location=torch.device("cpu")) | ||
for name, (src, path) in models.items(): | ||
if src == "torchhub": | ||
print(f"Downloading weights of `{name}` from `{path}`. Caching to dir `{args.target_directory}`") | ||
torch.hub.load_state_dict_from_url(path, model_dir=args.target_directory, map_location=torch.device("cpu")) | ||
elif src == "diffusers": | ||
print(f"Downloading `{name}` from diffusers. Caching to dir `{args.target_directory}`") | ||
if path == "StableDiffusionPipeline": | ||
diffusers.StableDiffusionPipeline.from_pretrained( | ||
name, cache_dir=args.target_directory, device_map="balanced" | ||
) | ||
|
||
raise SystemExit(0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import Any, Dict, Optional, Tuple | ||
|
||
from kornia.augmentation import random_generator as rg | ||
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D | ||
from kornia.core import Tensor | ||
from kornia.filters import StableDiffusionDissolving | ||
|
||
|
||
class RandomDissolving(IntensityAugmentationBase2D): | ||
r"""Perform dissolving transformation using StableDiffusion models. | ||
Based on :cite:`shi2024dissolving`, the dissolving transformation is essentially applying one-step | ||
reverse diffusion. Our implementation currently supports HuggingFace implementations of SD 1.4, 1.5 | ||
and 2.1. SD 1.X tends to remove more details than SD2.1. | ||
.. list-table:: Title | ||
:widths: 32 32 32 | ||
:header-rows: 1 | ||
* - SD 1.4 | ||
- SD 1.5 | ||
- SD 2.1 | ||
* - figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.4.png | ||
- figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.5.png | ||
- figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-2.1.png | ||
Args: | ||
p: probability of applying the transformation. | ||
version: the version of the stable diffusion model. | ||
step_range: the step range of the diffusion model steps. Higher the step, stronger | ||
the dissolving effects. | ||
keepdim: whether to keep the output shape the same as input (True) or broadcast it | ||
to the batch form (False). | ||
**kwargs: additional arguments for `.from_pretrained` for HF StableDiffusionPipeline. | ||
Shape: | ||
- Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`. | ||
- Output: :math:`(B, C, H, W)` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
step_range: Tuple[float, float] = (100, 500), | ||
version: str = "2.1", | ||
p: float = 0.5, | ||
keepdim: bool = False, | ||
**kwargs: Any, | ||
) -> None: | ||
super().__init__(p=p, same_on_batch=True, keepdim=keepdim) | ||
self.step_range = step_range | ||
self._dslv = StableDiffusionDissolving(version, **kwargs) | ||
self._param_generator = rg.PlainUniformGenerator((self.step_range, "step_range_factor", None, None)) | ||
|
||
def apply_transform( | ||
self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None | ||
) -> Tensor: | ||
return self._dslv(input, params["step_range_factor"][0].long().item()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.