diff --git a/.gitignore b/.gitignore index 32c9bf88d..c95a4cd30 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ cache *.lock *.zip *.rar +*.7z *.pyc /*.bat /*.sh diff --git a/.vscode/settings.json b/.vscode/settings.json index d5e8da26c..2b6588ae9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,5 +10,6 @@ "./repositories/stable-diffusion-stability-ai", "./repositories/stable-diffusion-stability-ai/ldm" ], - "python.analysis.typeCheckingMode": "off" -} + "python.analysis.typeCheckingMode": "off", + "editor.formatOnSave": false +} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index d62d893b7..42b738dd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,146 @@ # Change Log for SD.Next +## Update for 2023-12-29 + +- **Control** + - native implementation of all image control methods: + **ControlNet**, **ControlNet XS**, **Control LLLite**, **T2I Adapters** and **IP Adapters** + - top-level **Control** next to **Text** and **Image** generate + - supports all variations of **SD15** and **SD-XL** models + - supports *Text*, *Image*, *Batch* and *Video* processing + - for details and list of supported models and workflows, see Wiki documentation: + +- **Diffusers** + - [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega) model support + - small and fast version of **SDXL**, only 3.1GB in size! + - select from *networks -> reference* + - [aMUSEd 256](https://huggingface.co/amused/amused-256) and [aMUSEd 512](https://huggingface.co/amused/amused-512) model support + - lightweight models that excel at fast image generation + - *note*: must select: settings -> diffusers -> generator device: unset + - select from *networks -> reference* + - [Playground v1](https://huggingface.co/playgroundai/playground-v1), [Playground v2 256](https://huggingface.co/playgroundai/playground-v2-256px-base), [Playground v2 512](https://huggingface.co/playgroundai/playground-v2-512px-base), [Playground v2 1024](https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic) model support + - comparable to SD15 and SD-XL, trained from scratch for highly aesthetic images + - simply select from *networks -> reference* and use as usual + - [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/) + - img2img model that can replace subjects in images using prompt keywords + - download and load by selecting from *networks -> reference -> blip diffusion* + - in image tab, select `blip diffusion` script + - [DemoFusion](https://github.com/PRIS-CV/DemoFusion) run your SDXL generations at any resolution! + - in **Text** tab select *script* -> *demofusion* + - *note*: GPU VRAM limits do not automatically go away so be careful when using it with large resolutions + in the future, expect more optimizations, especially related to offloading/slicing/tiling, + but at the moment this is pretty much experimental-only + - [AnimateDiff](https://github.com/guoyww/animatediff/) + - overall improved quality + - can now be used with *second pass* - enhance, upscale and hires your videos!
+ - [IP Adapter](https://github.com/tencent-ailab/IP-Adapter) + - add support for **ip-adapter-plus_sd15, ip-adapter-plus-face_sd15 and ip-adapter-full-face_sd15** + - can now be used in *xyz-grid* + - **Text-to-Video** + - in text tab, select `text-to-video` script + - supported models: **ModelScope v1.7b, ZeroScope v1, ZeroScope v1.1, ZeroScope v2, ZeroScope v2 Dark, Potat v1** + *if you know of any other t2v models you'd like to see supported, let me know!* + - models are auto-downloaded on first use + - *note*: current base model will be unloaded to free up resources + - **Prompt scheduling** now implemented for Diffusers backend, thanks @AI-Casanova + - **Custom pipelines** contribute by adding your own custom pipelines! + - for details, see fully documented example: + + - **Schedulers** + - add timesteps range; changing it will make the scheduler over-complete or under-complete + - add rescale betas with zero SNR option (applicable to Euler, Euler a and DDIM, allows for higher dynamic range) + - **Inpaint** + - improved quality when using mask blur and padding + - **UI** + - 3 new native UI themes: **orchid-dreams**, **emerald-paradise** and **timeless-beige**, thanks @illu_Zn + - more dynamic controls depending on the backend (original or diffusers) + controls that are not applicable in current mode are now hidden + - allow setting of resize method directly in image tab + (previously via settings -> upscaler_for_img2img) +- **Optional** + - **FaceID** face guidance during generation + - also based on IP adapters, but with additional face detection and external embeddings calculation + - calculates face embeds based on input image and uses them to guide generation + - simply select from *scripts -> faceid* + - *experimental module*: requirements must be installed manually: + > pip install insightface ip_adapter + - **Depth 3D** image to 3D scene + - delivered as an extension, install from extensions tab + + - creates fully compatible 3D scene from any image by using depth estimation + and creating a fully populated mesh + - scene can be freely viewed in 3D in the UI itself or downloaded for use in other applications + - [ONNX/Olive](https://github.com/vladmandic/automatic/wiki/ONNX-Olive) + - major work continues in olive branch, see wiki for details, thanks @lshqqytiger + as a highlight, 4-5 it/s using DirectML on AMD GPU translates to 23-25 it/s using ONNX/Olive! +- **General** + - new **onboarding** + - if no models are found during startup, app will no longer ask to download default checkpoint + instead, it will show message in UI with options to change model path or download any of the reference checkpoints + - *extra networks -> models -> reference* section is now enabled for both original and diffusers backend + - support for **Torch 2.1.2** (release) and **Torch 2.3** (dev) + - **Process** create videos from batch or folder processing + supports *GIF*, *PNG* and *MP4* with full interpolation, scene change detection, etc.
+ - **LoRA** + - add support for block weights, thanks @AI-Casanova + example `` + - add support for LyCORIS GLora networks + - add support for LoRA PEFT (*Diffusers*) networks + - add support for Lora-OFT (*Kohya*) and Lyco-OFT (*Kohaku*) networks + - reintroduce alternative loading method in settings: `lora_force_diffusers` + - add support for `lora_fuse_diffusers` if using alternative method + use if you have multiple complex loras that may be causing performance degradation + as it fuses lora with model during load instead of interpreting lora on-the-fly + - **CivitAI downloader** allow usage of access tokens for download of gated or private models + - **Extra networks** new *settings -> extra networks -> build info on first access* + indexes all networks on first access instead of server startup + - **IPEX**, thanks @disty0 + - update to **Torch 2.1** + if you get file not found errors, set `DISABLE_IPEXRUN=1` and run the webui with `--reinstall` + - built-in *MKL* and *DPCPP* for IPEX, no need to install OneAPI anymore + - **StableVideoDiffusion** is now supported with IPEX + - **8 bit support with NNCF** on Diffusers backend + - fix IPEX Optimize not applying with Diffusers backend + - disable 32bit workarounds if the GPU supports 64bit + - add `DISABLE_IPEXRUN` and `DISABLE_IPEX_1024_WA` environment variables + - performance and compatibility improvements + - **OpenVINO**, thanks @disty0 + - **8 bit support for CPUs** + - reduce System RAM usage + - update to Torch 2.1.2 + - add *Directory for OpenVINO cache* option to *System Paths* + - remove Intel ARC specific 1024x1024 workaround + - **HDR controls** + - batch-aware for enhancement of multiple images or video frames + - available in image tab + - **Logging** + - additional *TRACE* logging enabled via specific env variables + see for details + - improved profiling + use with `--debug --profile` + - log output file sizes + - **Other** + - **API** several minor but breaking changes to API behavior to better align response fields, thanks @Trojaner + - **Inpaint** add option `apply_overlay` to control if inpaint result should be applied as overlay or as-is + can remove artifacts and hard edges of inpaint area but also remove some details from original + - **chaiNNer** fix `NaN` issues due to autocast + - **Upscale** increase limit from 4x to 8x given the quality of some upscalers + - **Extra Networks** fix sort + - reduced default **CFG scale** from 6 to 4 to be more out-of-the-box compatible with LCM/Turbo models + - disable google fonts check on server startup + - fix torchvision/basicsr compatibility + - fix styles quick save + - add hdr settings to metadata + - improve handling of long filenames and filenames during batch processing + - do not set preview samples when using via api + - avoid unnecessary resizes in img2img and inpaint + - safe handling of config updates to avoid file corruption on I/O errors + - updated `cli/simple-txt2img.py` and `cli/simple-img2img.py` scripts + - save `params.txt` regardless of image save status + - update built-in log monitor in ui, thanks @midcoastal + - major CHANGELOG doc cleanup, thanks @JetVarimax + - major INSTALL doc cleanup, thanks @JetVarimax + ## Update for 2023-12-04 Whats new? Native video in SD.Next via both **AnimateDiff** and **Stable-Video-Diffusion** - and including native MP4 encoding and smooth video outputs out-of-the-box, not just animated-GIFs.
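As a usage illustration for the reworked `cli/simple-txt2img.py` / `cli/simple-img2img.py` scripts and the API changes noted above, the sketch below condenses the txt2img call they wrap. It is illustrative only (not part of this changeset) and assumes a local SD.Next server at `http://127.0.0.1:7860` with authentication disabled; the endpoint and field names are taken from the scripts themselves.

```python
# Minimal sketch of the txt2img API flow wrapped by cli/simple-txt2img.py.
# Assumes a local SD.Next instance at http://127.0.0.1:7860 with no authentication.
import io
import base64
import requests
from PIL import Image

url = 'http://127.0.0.1:7860'  # assumed server address
options = {
    'prompt': 'city at night',
    'negative_prompt': 'foggy, blurry',
    'steps': 20,
    'seed': -1,
    'sampler_name': 'Euler a',
    'width': 512,
    'height': 512,
    'save_images': False,
    'send_images': True,
}
data = requests.post(f'{url}/sdapi/v1/txt2img', json=options, timeout=300).json()
if 'images' in data:
    b64 = data['images'][0].split(',', 1)[0]  # images are returned base64-encoded
    image = Image.open(io.BytesIO(base64.b64decode(b64)))
    image.save('/tmp/simple-txt2img.jpg')
```

The equivalent invocation through the updated script itself would be along the lines of `python cli/simple-txt2img.py --prompt "city at night" --sampler "Euler a" --steps 20`.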
diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 000000000..f7fd4bba3 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,28 @@ +cff-version: 1.2.0 +title: SD.Next +url: 'https://github.com/vladmandic/automatic' +message: >- + If you use this software, please cite it using the + metadata from this file +type: software +authors: + - given-names: Vladimir + name-particle: Vlado + family-names: Mandic + orcid: 'https://orcid.org/0009-0003-4592-5074' +identifiers: + - type: url + value: 'https://github.com/vladmandic' + description: GitHub + - type: url + value: 'https://www.linkedin.com/in/cyan051/' + description: LinkedIn +repository-code: 'https://github.com/vladmandic/automatic' +abstract: >- + SD.Next: Advanced Implementation of Stable Diffusion and + other diffusion models for text, image and video + generation +keywords: + - stablediffusion diffusers sdnext +license: AGPL-3.0 +date-released: 2022-12-24 diff --git a/README.md b/README.md index 392542840..7c3c49add 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,9 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG - Multiple backends! ▹ **Original | Diffusers** - Multiple diffusion models! - ▹ **Stable Diffusion | SD-XL | LCM | Segmind | Kandinsky | Pixart-α | Würstchen | DeepFloyd IF | UniDiffusion | SD-Distilled | etc.** + ▹ **Stable Diffusion 1.5/2.1 | SD-XL | LCM | Segmind | Kandinsky | Pixart-α | Würstchen | aMUSEd | DeepFloyd IF | UniDiffusion | SD-Distilled | BLiP Diffusion | etc.** +- Built-in Control for Text, Image, Batch and video processing! + ▹ **ControlNet | ControlNet XS | Control LLLite | T2I Adapters | IP Adapters** - Multiplatform! ▹ **Windows | Linux | MacOS with CPU | nVidia | AMD | IntelArc | DirectML | OpenVINO | ONNX+Olive** - Platform specific autodetection and tuning performed on install @@ -28,7 +30,6 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG - Improved prompt parser - Enhanced *Lora*/*LoCon*/*Lyco* code supporting latest trends in training - Built-in queue management -- Advanced metadata caching and handling to speed up operations - Enterprise level logging and hardened API - Modern localization and hints engine - Broad compatibility with existing extensions ecosystem and new extensions manager @@ -37,7 +38,8 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG
-![Screenshot-Dark](html/black-teal.jpg) +![Screenshot-Dark](html/xmas-default.jpg) +![Screenshot-Control](html/xmas-control.jpg) ![Screenshot-Light](html/light-teal.jpg)
@@ -58,17 +60,23 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG Additional models will be added as they become available and there is public interest in them -- [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)* -- [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models) -- [StabilityAI Stable Video Diffusion Base and XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) -- [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B) -- [LCM: Latent Consistency Models](https://github.com/openai/consistency_models) -- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1 and 2.2 and latest 3.0* -- [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large* +- [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)* +- [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models) +- [StabilityAI Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) Base and XT +- [LCM: Latent Consistency Models](https://github.com/openai/consistency_models) +- [aMUSEd 256](https://huggingface.co/amused/amused-256) 256 and 512 +- [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega) +- [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B) +- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1 and 2.2 and latest 3.0* +- [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large* - [Warp Wuerstchen](https://huggingface.co/blog/wuertschen) +- [Playground](https://huggingface.co/playgroundai/playground-v2-256px-base) *v1, v2 256, v2 512, v2 1024* - [Tsinghua UniDiffusion](https://github.com/thu-ml/unidiffuser) - [DeepFloyd IF](https://github.com/deep-floyd/IF) *Medium and Large* +- [ModelScope T2V](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b) - [Segmind SD Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)* +- [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/) + Also supported are modifiers such as: - **LCM** and **Turbo** (Adversarial Diffusion Distillation) networks @@ -209,6 +217,9 @@ General goals: ### **Docs** +If you're unsure how to use a feature, the best place to start is the [Wiki](https://github.com/vladmandic/automatic/wiki) and if it's not there, +check [ChangeLog](CHANGELOG.md) for when the feature was first introduced, as it will always have a short note on how to use it + - [Wiki](https://github.com/vladmandic/automatic/wiki) - [ReadMe](README.md) - [ToDo](TODO.md) diff --git a/SECURITY.md b/SECURITY.md index af9687f8c..9c1e11bb0 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -32,5 +32,5 @@ Any code commit is validated before merge - Download extensions and themes indexes from automatically updated indexes - Download required packages and repositories from GitHub during installation/upgrade - Download installed/enabled extensions -- Download default model from official repository +- Download models from CivitAI and/or Huggingface when instructed by user - Submit benchmark info upon user interaction diff --git a/cli/image-grid.py b/cli/image-grid.py index a1fd12d17..743a12bdf 100755 --- a/cli/image-grid.py +++ b/cli/image-grid.py @@ -56,7 +56,7 @@ def grid(images, labels = None, width = 0, height = 0, border = 0, square = Fals for i, img in enumerate(images): # pylint: disable=redefined-outer-name x = (i % cols * w) + (i % cols * border) y = (i // cols * h) + (i
// cols * border) - img.thumbnail((w, h), Image.HAMMING) + img.thumbnail((w, h), Image.Resampling.HAMMING) image.paste(img, box=(x, y)) if labels is not None and len(images) == len(labels): ctx = ImageDraw.Draw(image) diff --git a/cli/simple-img2img.py b/cli/simple-img2img.py index 8590fc62a..41043fc93 100755 --- a/cli/simple-img2img.py +++ b/cli/simple-img2img.py @@ -1,9 +1,10 @@ #!/usr/bin/env python import os import io -import sys +import time import base64 import logging +import argparse import requests import urllib3 from PIL import Image @@ -14,22 +15,13 @@ logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s') log = logging.getLogger(__name__) +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +filename='/tmp/simple-img2img.jpg' options = { - "init_images": [], - "prompt": "city at night", - "negative_prompt": "foggy, blurry", - "steps": 20, - "batch_size": 1, - "n_iter": 1, - "seed": -1, - "sampler_name": "Euler a", - "cfg_scale": 6, - "width": 512, - "height": 512, "save_images": False, "send_images": True, } -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) def auth(): @@ -51,26 +43,51 @@ def encode(f): image = image.convert('RGB') with io.BytesIO() as stream: image.save(stream, 'JPEG') + image.close() values = stream.getvalue() encoded = base64.b64encode(values).decode() return encoded -def generate(num: int = 0): - log.info(f'sending generate request: {num+1} {options}') - options['init_images'] = [encode('html/logo-dark.png')] - options['batch_size'] = len(options['init_images']) +def generate(args): # pylint: disable=redefined-outer-name + t0 = time.time() + if args.model is not None: + post('/sdapi/v1/options', { 'sd_model_checkpoint': args.model }) + post('/sdapi/v1/reload-checkpoint') # needed if running in api-only to trigger new model load + options['prompt'] = args.prompt + options['negative_prompt'] = args.negative + options['steps'] = int(args.steps) + options['seed'] = int(args.seed) + options['sampler_name'] = args.sampler + options['init_images'] = [encode(args.init)] + image = Image.open(args.init) + options['width'] = image.width + options['height'] = image.height + image.close() + if args.mask is not None: + options['mask'] = encode(args.mask) data = post('/sdapi/v1/img2img', options) + t1 = time.time() if 'images' in data: for i in range(len(data['images'])): b64 = data['images'][i].split(',',1)[0] + info = data['info'] image = Image.open(io.BytesIO(base64.b64decode(b64))) - log.info(f'received image: {image.size}') + image.save(filename) + log.info(f'received image: size={image.size} file={filename} time={t1-t0:.2f} info="{info}"') else: log.warning(f'no images received: {data}') + if __name__ == "__main__": - sys.argv.pop(0) - repeats = int(''.join(sys.argv) or '1') - log.info(f'repeats: {repeats}') - for n in range(repeats): - generate(n) + parser = argparse.ArgumentParser(description = 'simple-img2img') + parser.add_argument('--init', required=True, help='init image') + parser.add_argument('--mask', required=False, help='mask image') + parser.add_argument('--prompt', required=False, default='', help='prompt text') + parser.add_argument('--negative', required=False, default='', help='negative prompt text') + parser.add_argument('--steps', required=False, default=20, help='number of steps') + parser.add_argument('--seed', required=False, default=-1, help='initial seed') + parser.add_argument('--sampler', required=False, default='Euler a', help='sampler name') + 
parser.add_argument('--model', required=False, help='model name') + args = parser.parse_args() + log.info(f'img2img: {args}') + generate(args) diff --git a/cli/simple-txt2img.py b/cli/simple-txt2img.py index 70e60a916..c2a5ee001 100755 --- a/cli/simple-txt2img.py +++ b/cli/simple-txt2img.py @@ -1,9 +1,10 @@ #!/usr/bin/env python import io import os -import sys +import time import base64 import logging +import argparse import requests import urllib3 from PIL import Image @@ -17,18 +18,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) filename='/tmp/simple-txt2img.jpg' -model = None # desired model name, will be set if not none options = { - "prompt": "city at night", - "negative_prompt": "foggy, blurry", - "steps": 20, - "batch_size": 1, - "n_iter": 1, - "seed": -1, - "sampler_name": "UniPC", - "cfg_scale": 6, - "width": 512, - "height": 512, "save_images": False, "send_images": True, } @@ -48,25 +38,41 @@ def post(endpoint: str, dct: dict = None): return req.json() -def generate(num: int = 0): - log.info(f'sending generate request: {num+1} {options}') - if model is not None: - post('/sdapi/v1/options', { 'sd_model_checkpoint': model }) +def generate(args): # pylint: disable=redefined-outer-name + t0 = time.time() + if args.model is not None: + post('/sdapi/v1/options', { 'sd_model_checkpoint': args.model }) post('/sdapi/v1/reload-checkpoint') # needed if running in api-only to trigger new model load + options['prompt'] = args.prompt + options['negative_prompt'] = args.negative + options['steps'] = int(args.steps) + options['seed'] = int(args.seed) + options['sampler_name'] = args.sampler + options['width'] = int(args.width) + options['height'] = int(args.height) data = post('/sdapi/v1/txt2img', options) + t1 = time.time() if 'images' in data: for i in range(len(data['images'])): b64 = data['images'][i].split(',',1)[0] image = Image.open(io.BytesIO(base64.b64decode(b64))) + info = data['info'] image.save(filename) - log.info(f'received image: size={image.size} file={filename}') + log.info(f'received image: size={image.size} file={filename} time={t1-t0:.2f} info="{info}"') else: log.warning(f'no images received: {data}') if __name__ == "__main__": - sys.argv.pop(0) - repeats = int(''.join(sys.argv) or '1') - log.info(f'repeats: {repeats}') - for n in range(repeats): - generate(n) + parser = argparse.ArgumentParser(description = 'simple-txt2img') + parser.add_argument('--prompt', required=False, default='', help='prompt text') + parser.add_argument('--negative', required=False, default='', help='negative prompt text') + parser.add_argument('--width', required=False, default=512, help='image width') + parser.add_argument('--height', required=False, default=512, help='image height') + parser.add_argument('--steps', required=False, default=20, help='number of steps') + parser.add_argument('--seed', required=False, default=-1, help='initial seed') + parser.add_argument('--sampler', required=False, default='Euler a', help='sampler name') + parser.add_argument('--model', required=False, help='model name') + args = parser.parse_args() + log.info(f'txt2img: {args}') + generate(args) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index eb9a40b0d..6cdfd03a8 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -35,8 +35,11 @@ def activate(self, p, params_list): names.append(params.positional[0]) te_multiplier = float(params.positional[1]) if 
len(params.positional) > 1 else 1.0 te_multiplier = float(params.named.get("te", te_multiplier)) - unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier - unet_multiplier = float(params.named.get("unet", unet_multiplier)) + unet_multiplier = [float(params.positional[2]) if len(params.positional) > 2 else te_multiplier] * 3 + unet_multiplier = [float(params.named.get("unet", unet_multiplier[0]))] * 3 + unet_multiplier[0] = float(params.named.get("in", unet_multiplier[0])) + unet_multiplier[1] = float(params.named.get("mid", unet_multiplier[1])) + unet_multiplier[2] = float(params.named.get("out", unet_multiplier[2])) dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim te_multipliers.append(te_multiplier) @@ -59,13 +62,15 @@ def activate(self, p, params_list): if network_hashes: p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) if len(names) > 0: - shared.log.info(f'Applying LoRA: {names} patch={t1-t0:.2f} load={t2-t1:.2f}') + shared.log.info(f'LoRA apply: {names} patch={t1-t0:.2f} load={t2-t1:.2f}') elif self.active: self.active = False def deactivate(self, p): if shared.backend == shared.Backend.DIFFUSERS and hasattr(shared.sd_model, "unload_lora_weights") and hasattr(shared.sd_model, "text_encoder"): if 'CLIP' in shared.sd_model.text_encoder.__class__.__name__ and not (shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx"): + if shared.opts.lora_fuse_diffusers: + shared.sd_model.unfuse_lora() shared.sd_model.unload_lora_weights() if not self.active and getattr(networks, "originals", None ) is not None: networks.originals.undo() # remove patches diff --git a/extensions-builtin/Lora/lora_convert.py b/extensions-builtin/Lora/lora_convert.py index fb314f258..65f6a0adb 100644 --- a/extensions-builtin/Lora/lora_convert.py +++ b/extensions-builtin/Lora/lora_convert.py @@ -1,9 +1,11 @@ -from typing import Dict +import os import re import bisect +from typing import Dict from modules import shared +debug = os.environ.get('SD_LORA_DEBUG', None) is not None suffix_conversion = { "attentions": {}, "resnets": { @@ -144,12 +146,13 @@ def diffusers(self, key): map_keys = list(self.UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules map_keys.sort() search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "") - position = bisect.bisect_right(map_keys, search_key) map_key = map_keys[position - 1] if search_key.startswith(map_key): - key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft","lora") # pylint: disable=unsubscriptable-object + key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora") # pylint: disable=unsubscriptable-object sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if debug and sd_module is None: + raise RuntimeError(f"LoRA key not found in network_layer_mapping: key={key} mapping={shared.sd_model.network_layer_mapping.keys()}") return key, sd_module def __call__(self, key): diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py index 983d80240..9ea6a10c3 100644 --- a/extensions-builtin/Lora/lora_patches.py +++ b/extensions-builtin/Lora/lora_patches.py @@ -1,10 +1,7 @@ -import os import torch import networks from modules import patches, shared -# OpenVINO only works with Diffusers LoRa loading 
-force_lora_diffusers = os.environ.get('SD_LORA_DIFFUSERS', None) is not None class LoraPatches: def __init__(self): @@ -21,7 +18,7 @@ self.MultiheadAttention_load_state_dict = None def apply(self): - if self.active or force_lora_diffusers: + if self.active or shared.opts.lora_force_diffusers: return self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward) self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict) @@ -39,7 +36,7 @@ def apply(self): self.active = True def undo(self): - if not self.active or force_lora_diffusers: + if not self.active or shared.opts.lora_force_diffusers: return self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') # pylint: disable=E1128 self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') # pylint: disable=E1128 diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 279b34bc9..1679a0ce6 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid): up = up.reshape(up.size(0), -1) down = down.reshape(down.size(0), -1) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two values of input dimension decomposed by the number closest to factor + second value is greater than or equal to the first value. + + In LoRA with Kronecker Product, first value is a value for weight scale. + second value is a value for weight. + + Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ...
+ 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m<n: + new_m = m + 1 + while dimension%new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index d7c158770..e5828daf3 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -83,7 +83,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk): self.name = name self.network_on_disk = network_on_disk self.te_multiplier = 1.0 - self.unet_multiplier = 1.0 + self.unet_multiplier = [1.0] * 3 self.dyn_dim = None self.modules = {} self.mtime = None @@ -112,8 +112,14 @@ def __init__(self, net: Network, weights: NetworkWeights): def multiplier(self): if 'transformer' in self.sd_key[:20]: return self.network.te_multiplier + if "down_blocks" in self.sd_key: + return self.network.unet_multiplier[0] + if "mid_block" in self.sd_key: + return self.network.unet_multiplier[1] + if "up_blocks" in self.sd_key: + return self.network.unet_multiplier[2] else: - return self.network.unet_multiplier + return self.network.unet_multiplier[0] def calc_scale(self): if self.scale is not None: diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e96..233791712 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -16,12 +16,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.weight = weights.w.get("diff") self.ex_bias = weights.w.get("diff_b") - def calc_updown(self, orig_weight): + def calc_updown(self, target): output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.weight.to(target.device, dtype=target.dtype) if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.ex_bias.to(target.device, dtype=target.dtype) else: ex_bias = None - return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) + return self.finalize_updown(updown, target, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py new file mode 100644 index 000000000..ce6ceaa1b --- /dev/null +++ b/extensions-builtin/Lora/network_glora.py @@ -0,0 +1,30 @@ + +import network + +class ModuleTypeGLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): + return NetworkModuleGLora(net, weights) + return None + +# adapted from https://github.com/KohakuBlueleaf/LyCORIS +class NetworkModuleGLora(network.NetworkModule): # pylint: disable=abstract-method + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape =
self.sd_module.weight.shape + + self.w1a = weights.w["a1.weight"] + self.w1b = weights.w["b1.weight"] + self.w2a = weights.w["a2.weight"] + self.w2b = weights.w["b2.weight"] + + def calc_updown(self, target): # pylint: disable=arguments-differ + w1a = self.w1a.to(target.device, dtype=target.dtype) + w1b = self.w1b.to(target.device, dtype=target.dtype) + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) + output_shape = [w1a.size(0), w1b.size(1)] + updown = (w2b @ w1b) + ((target @ w2a) @ w1a) + return self.finalize_updown(updown, target, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 78e8c1569..0feda761e 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -22,15 +22,15 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.t1 = weights.w.get("hada_t1") self.t2 = weights.w.get("hada_t2") - def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + def calc_updown(self, target): + w1a = self.w1a.to(target.device, dtype=target.dtype) + w1b = self.w1b.to(target.device, dtype=target.dtype) + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) output_shape = [w1a.size(0), w1b.size(1)] if self.t1 is not None: output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + t1 = self.t1.to(target.device, dtype=target.dtype) updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) output_shape += t1.shape[2:] else: @@ -38,9 +38,9 @@ def calc_updown(self, orig_weight): output_shape += w1b.shape[2:] updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) if self.t2 is not None: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(target.device, dtype=target.dtype) updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) else: updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) updown = updown1 * updown2 - return self.finalize_updown(updown, orig_weight, output_shape) + return self.finalize_updown(updown, target, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index f8d86926f..cb39df228 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -15,12 +15,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.w = weights.w["weight"] self.on_input = weights.w["on_input"].item() - def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) - output_shape = [w.size(0), orig_weight.size(1)] + def calc_updown(self, target): + w = self.w.to(target.device, dtype=target.dtype) + output_shape = [w.size(0), target.size(1)] if self.on_input: output_shape.reverse() else: w = w.reshape(-1, 1) - updown = orig_weight * w - return self.finalize_updown(updown, orig_weight, output_shape) + updown = target * w + return self.finalize_updown(updown, target, output_shape) diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 426d64308..20387efee 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ 
b/extensions-builtin/Lora/network_lokr.py @@ -32,26 +32,26 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim self.t2 = weights.w.get("lokr_t2") - def calc_updown(self, orig_weight): + def calc_updown(self, target): if self.w1 is not None: - w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = self.w1.to(target.device, dtype=target.dtype) else: - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(target.device, dtype=target.dtype) + w1b = self.w1b.to(target.device, dtype=target.dtype) w1 = w1a @ w1b if self.w2 is not None: - w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = self.w2.to(target.device, dtype=target.dtype) elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) w2 = w2a @ w2b else: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(target.device, dtype=target.dtype) + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] - if len(orig_weight.shape) == 4: - output_shape = orig_weight.shape + if len(target.shape) == 4: + output_shape = target.shape updown = make_kron(output_shape, w1, w2) - return self.finalize_updown(updown, orig_weight, output_shape) + return self.finalize_updown(updown, target, output_shape) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 5dcb05322..8c2c4c8a5 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -51,20 +51,20 @@ def create_module(self, weights, key, none_ok=False): module.weight.requires_grad_(False) return module - def calc_updown(self, orig_weight): # pylint: disable=W0237 - up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + def calc_updown(self, target): # pylint: disable=W0237 + up = self.up_model.weight.to(target.device, dtype=target.dtype) + down = self.down_model.weight.to(target.device, dtype=target.dtype) output_shape = [up.size(0), down.size(1)] if self.mid_model is not None: # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + mid = self.mid_model.weight.to(target.device, dtype=target.dtype) updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) output_shape += mid.shape[2:] else: if len(down.shape) == 4: output_shape += down.shape[2:] updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) - return self.finalize_updown(updown, orig_weight, output_shape) + return self.finalize_updown(updown, target, output_shape) def forward(self, x, y): self.up_model.to(device=devices.device) diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index a291fbad3..f327b9754 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -14,11 
+14,11 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.w_norm = weights.w.get("w_norm") self.b_norm = weights.w.get("b_norm") - def calc_updown(self, orig_weight): + def calc_updown(self, target): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(target.device, dtype=target.dtype) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(target.device, dtype=target.dtype) else: ex_bias = None - return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) + return self.finalize_updown(updown, target, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 6d350671a..6cadc36d0 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,49 +1,85 @@ import torch -import diffusers.models.lora as diffusers_lora import network -from modules import devices +from lyco_helpers import factorization +from einops import rearrange + class ModuleTypeOFT(network.ModuleType): def create_module(self, net: network.Network, weights: network.NetworkWeights): - """ - weights.w.items() + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): + return NetworkModuleOFT(net, weights) + + return None + +# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + + super().__init__(net, weights) + + self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] - alpha : tensor(0.0010, dtype=torch.bfloat16) - oft_blocks : tensor([[[ 0.0000e+00, 1.4400e-04, 1.7319e-03, ..., -8.8882e-04, - 5.7373e-03, -4.4250e-03], - [-1.4400e-04, 0.0000e+00, 8.6594e-04, ..., 1.5945e-03, - -8.5449e-04, 1.9684e-03], ...etc... 
- , dtype=torch.bfloat16)""" + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): - module = NetworkModuleOFT(net, weights) - return module - else: - return None + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) + self.alpha = weights.w["alpha"] # alpha is constraint + self.dim = self.oft_blocks.shape[0] # lora dim + # LyCORIS + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported -class NetworkModuleOFT(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) + if is_linear: + self.out_dim = self.sd_module.out_features + elif is_conv: + self.out_dim = self.sd_module.out_channels + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim - self.weights = weights.w.get("oft_blocks").to(device=devices.device) - self.dim = self.weights.shape[0] # num blocks - self.alpha = self.multiplier() - self.block_size = self.weights.shape[-1] - - def get_weight(self): - block_Q = self.weights - self.weights.transpose(1, 2) - I = torch.eye(self.block_size, device=devices.device).unsqueeze(0).repeat(self.dim, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.alpha * block_R + (1 - self.alpha) * I - R = torch.block_diag(*block_R_weighted) - return R - - def calc_updown(self, orig_weight): - R = self.get_weight().to(device=devices.device, dtype=orig_weight.dtype) - if orig_weight.dim() == 4: - updown = torch.einsum("oihw, op -> pihw", orig_weight, R) * self.calc_scale() + if self.is_kohya: + self.constraint = self.alpha * self.out_dim + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim else: - updown = torch.einsum("oi, op -> pi", orig_weight, R) * self.calc_scale() + self.constraint = None + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) + + def calc_updown(self, target): + oft_blocks = self.oft_blocks.to(target.device, dtype=target.dtype) + eye = torch.eye(self.block_size, device=target.device) + constraint = self.constraint.to(target.device) + + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()).to(target.device) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + mat1 = eye + block_Q + mat2 = (eye - block_Q).float().inverse() + oft_blocks = torch.matmul(mat1, mat2) + + R = oft_blocks.to(target.device, dtype=target.dtype) + + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(target, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... 
-> (k m) ...') - return self.finalize_updown(updown, orig_weight, orig_weight.shape) + updown = merged_weight.to(target.device, dtype=target.dtype) - target + output_shape = target.shape + return self.finalize_updown(updown, target, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 40f19bde1..b0128eb43 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, List import os import re import time @@ -11,18 +11,19 @@ import network_lokr import network_full import network_norm +import network_glora import lora_convert import torch import diffusers.models.lora from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, sd_hijack -debug = os.environ.get('SD_LORA_DEBUG', None) +debug = os.environ.get('SD_LORA_DEBUG', None) is not None originals: lora_patches.LoraPatches = None extra_network_lora = None available_networks = {} available_network_aliases = {} -loaded_networks = [] +loaded_networks: List[network.Network] = [] timer = { 'load': 0, 'apply': 0, 'restore': 0 } # networks_in_memory = {} lora_cache = {} @@ -37,6 +38,7 @@ network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), + network_glora.ModuleTypeGLora(), ] convert_diffusers_name_to_compvis = lora_convert.convert_diffusers_name_to_compvis # supermerger compatibility item @@ -74,17 +76,17 @@ def assign_network_names_to_compvis_modules(sd_model): sd_model.network_layer_mapping = network_layer_mapping -def load_diffusers(name, network_on_disk, lora_scale=1.0): +def load_diffusers(name, network_on_disk, lora_scale=1.0) -> network.Network: t0 = time.time() cached = lora_cache.get(name, None) # if debug: - shared.log.debug(f'LoRA load: name={name} file={network_on_disk.filename} type=diffusers {"cached" if cached else ""}') + shared.log.debug(f'LoRA load: name="{name}" file="{network_on_disk.filename}" type=diffusers {"cached" if cached else ""} fuse={shared.opts.lora_fuse_diffusers}') if cached is not None: return cached if shared.backend != shared.Backend.DIFFUSERS: return None shared.sd_model.load_lora_weights(network_on_disk.filename) - if shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx": + if shared.opts.lora_fuse_diffusers: shared.sd_model.fuse_lora(lora_scale=lora_scale) net = network.Network(name, network_on_disk) net.mtime = os.path.getmtime(network_on_disk.filename) @@ -94,11 +96,11 @@ def load_diffusers(name, network_on_disk, lora_scale=1.0): return net -def load_network(name, network_on_disk): +def load_network(name, network_on_disk) -> network.Network: t0 = time.time() cached = lora_cache.get(name, None) if debug: - shared.log.debug(f'LoRA load: name={name} file={network_on_disk.filename} {"cached" if cached else ""}') + shared.log.debug(f'LoRA load: name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}') if cached is not None: return cached net = network.Network(name, network_on_disk) @@ -109,7 +111,16 @@ def load_network(name, network_on_disk): matched_networks = {} convert = lora_convert.KeyConvert() for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + parts = key_network.split('.') + if len(parts) > 5: # messy handler for diffusers peft lora + key_network_without_network_parts = '_'.join(parts[:-2]) + if not key_network_without_network_parts.startswith('lora_'): + 
key_network_without_network_parts = 'lora_' + key_network_without_network_parts + network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up') + else: + key_network_without_network_parts, network_part = key_network.split(".", 1) + # if debug: + # shared.log.debug(f'LoRA load: name="{name}" full={key_network} network={network_part} key={key_network_without_network_parts}') key, sd_module = convert(key_network_without_network_parts) if sd_module is None: keys_failed_to_match[key_network] = key @@ -124,12 +135,15 @@ def load_network(name, network_on_disk): if net_module is not None: break if net_module is None: - raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") - net.modules[key] = net_module - if keys_failed_to_match: - shared.log.warning(f"LoRA unmatched keys: file={network_on_disk.filename} keys={len(keys_failed_to_match)}") + shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}') + else: + net.modules[key] = net_module + if len(keys_failed_to_match) > 0: + shared.log.warning(f"LoRA file={network_on_disk.filename} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}") if debug: - shared.log.debug(f"LoRA unmatched keys: file={network_on_disk.filename} keys={keys_failed_to_match}") + shared.log.debug(f"LoRA file={network_on_disk.filename} unmatched={keys_failed_to_match}") + elif debug: + shared.log.debug(f"LoRA file={network_on_disk.filename} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}") lora_cache[name] = net t1 = time.time() timer['load'] += t1 - t0 @@ -167,10 +181,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): net = None if network_on_disk is not None: + if debug: + shared.log.debug(f'LoRA load start: name="{name}" file="{network_on_disk.filename}"') try: if recompile_model: shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else 1.0}") - if shared.backend == shared.Backend.DIFFUSERS and (os.environ.get('SD_LORA_DIFFUSERS', None) is not None): # OpenVINO only works with Diffusers LoRa loading. + if shared.backend == shared.Backend.DIFFUSERS and shared.opts.lora_force_diffusers: # OpenVINO only works with Diffusers LoRa loading. 
# or getattr(network_on_disk, 'shorthash', '').lower() == 'aaebf6360f7d' # sd15-lcm # or getattr(network_on_disk, 'shorthash', '').lower() == '3d18b05e4f56' # sdxl-lcm # or getattr(network_on_disk, 'shorthash', '').lower() == '813ea5fb1c67' # turbo sdxl-turbo @@ -186,7 +202,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No network_on_disk.read_hash() if net is None: failed_to_load_networks.append(name) - shared.log.error(f"LoRA unknown: network={name}") + shared.log.error(f"LoRA unknown type: network={name}") continue net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0 @@ -269,10 +285,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if current_names != wanted_names: network_restore_weights_from_backup(self) for net in loaded_networks: + # default workflow where module is known and has weights module = net.modules.get(network_layer_name, None) if module is not None and hasattr(self, 'weight'): try: - with torch.no_grad(): + with devices.inference_context(): updown, ex_bias = module.calc_updown(self.weight) if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9 @@ -284,17 +301,21 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn else: self.bias += ex_bias except RuntimeError as e: - if debug: - shared.log.debug(f"LoRA apply weight network={net.name} layer={network_layer_name} {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + if debug: + module_name = net.modules.get(network_layer_name, None) + shared.log.error(f"LoRA apply weight name={net.name} module={module_name} layer={network_layer_name} {e}") + errors.display(e, 'LoRA apply weight') + raise RuntimeError('LoRA apply weight') from e continue + # alternative workflow looking at _*_proj layers module_q = net.modules.get(network_layer_name + "_q_proj", None) module_k = net.modules.get(network_layer_name + "_k_proj", None) module_v = net.modules.get(network_layer_name + "_v_proj", None) module_out = net.modules.get(network_layer_name + "_out_proj", None) if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: try: - with torch.no_grad(): + with devices.inference_context(): updown_q, _ = module_q.calc_updown(self.in_proj_weight) updown_k, _ = module_k.calc_updown(self.in_proj_weight) updown_v, _ = module_v.calc_updown(self.in_proj_weight) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index bd408b327..cee5ad29e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -19,7 +19,6 @@ def create_item(self, name): try: path, _ext = os.path.splitext(l.filename) name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0] - if shared.backend == shared.Backend.ORIGINAL: if l.sd_version == network.SdVersion.SDXL: return None diff --git a/extensions-builtin/sd-extension-chainner b/extensions-builtin/sd-extension-chainner index 1cdc8578a..e48020d03 160000 --- a/extensions-builtin/sd-extension-chainner +++ b/extensions-builtin/sd-extension-chainner @@ -1 +1 @@ -Subproject commit 1cdc8578a3c0538177c3ac13721b9764f8782c8f +Subproject commit e48020d035b0d4daffec37a2bbce61c3ea04798c diff --git a/extensions-builtin/sd-extension-system-info 
b/extensions-builtin/sd-extension-system-info index 1841cf762..f93480869 160000 --- a/extensions-builtin/sd-extension-system-info +++ b/extensions-builtin/sd-extension-system-info @@ -1 +1 @@ -Subproject commit 1841cf7627d3991d7ba3dfe8b86a14b18f78f933 +Subproject commit f934808698523465e7abb2d1fb0d290065def776 diff --git a/extensions-builtin/sd-webui-agent-scheduler b/extensions-builtin/sd-webui-agent-scheduler index a8d527b26..435dd9bde 160000 --- a/extensions-builtin/sd-webui-agent-scheduler +++ b/extensions-builtin/sd-webui-agent-scheduler @@ -1 +1 @@ -Subproject commit a8d527b269358595bd7f9c839f266b1db5bea468 +Subproject commit 435dd9bdec0fdb22f73f645de76b55684ed16e67 diff --git a/extensions-builtin/sd-webui-controlnet b/extensions-builtin/sd-webui-controlnet index 10bd9b25f..dc0d590fa 160000 --- a/extensions-builtin/sd-webui-controlnet +++ b/extensions-builtin/sd-webui-controlnet @@ -1 +1 @@ -Subproject commit 10bd9b25f62deab9acb256301bbf3363c42645e7 +Subproject commit dc0d590faf3240746212477a95c3026aa7e9584f diff --git a/extensions-builtin/stable-diffusion-webui-images-browser b/extensions-builtin/stable-diffusion-webui-images-browser index 08fc2647f..27fe4a713 160000 --- a/extensions-builtin/stable-diffusion-webui-images-browser +++ b/extensions-builtin/stable-diffusion-webui-images-browser @@ -1 +1 @@ -Subproject commit 08fc2647f1fe413699612df923b5f495d26853ef +Subproject commit 27fe4a713d883436049ed1ed9e1642f8eb4fd924 diff --git a/html/emerald-paradise.jpg b/html/emerald-paradise.jpg new file mode 100644 index 000000000..73a4b7a5f Binary files /dev/null and b/html/emerald-paradise.jpg differ diff --git a/html/locale_en.json b/html/locale_en.json index 27d4beaa7..91789b742 100644 --- a/html/locale_en.json +++ b/html/locale_en.json @@ -43,17 +43,18 @@ "tabs": [ {"id":"","label":"Text","localized":"","hint":"Create image from text"}, {"id":"","label":"Image","localized":"","hint":"Create image from image"}, + {"id":"","label":"Control","localized":"","hint":"Create image with additional control"}, {"id":"","label":"Process","localized":"","hint":"Process existing image"}, - {"id":"","label":"Train","localized":"","hint":"Run training or model merging"}, + {"id":"","label":"Interrogate","localized":"","hint":"Run interrogate to get description of your image"}, + {"id":"","label":"Train","localized":"","hint":"Run training"}, {"id":"","label":"Models","localized":"","hint":"Convert or merge your models"}, - {"id":"","label":"Interrogator","localized":"","hint":"Run interrogate to get description of your image"}, - {"id":"","label":"System Info","localized":"","hint":"System information"}, {"id":"","label":"Agent Scheduler","localized":"","hint":"Enqueue your generate requests and run them in the background"}, {"id":"","label":"Image Browser","localized":"","hint":"Browse through your generated image database"}, {"id":"","label":"System","localized":"","hint":"System settings and information"}, + {"id":"","label":"System Info","localized":"","hint":"System information"}, {"id":"","label":"Settings","localized":"","hint":"Application settings"}, - {"id":"","label":"Extensions","localized":"","hint":"Application extensions"}, - {"id":"","label":"Script","localized":"","hint":"Addtional scripts to be used"} + {"id":"","label":"Script","localized":"","hint":"Addtional scripts to be used"}, + {"id":"","label":"Extensions","localized":"","hint":"Application extensions"} ], "action panel": [ {"id":"","label":"Generate","localized":"","hint":"Start processing"}, @@ -226,10 +227,10 @@ 
{"id":"","label":"Inpaint batch input directory","localized":"","hint":""}, {"id":"","label":"Inpaint batch output directory","localized":"","hint":""}, {"id":"","label":"Inpaint batch mask directory","localized":"","hint":""}, - {"id":"","label":"Resize fixed","localized":"","hint":"Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio"}, - {"id":"","label":"Crop and resize","localized":"","hint":"Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out"}, - {"id":"","label":"Resize and fill","localized":"","hint":"Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors"}, - {"id":"","label":"Latent upscale","localized":"","hint":""}, + {"id":"","label":"Fixed","localized":"","hint":"Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio"}, + {"id":"","label":"Crop","localized":"","hint":"Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out"}, + {"id":"","label":"Fill","localized":"","hint":"Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors"}, + {"id":"","label":"Latent","localized":"","hint":""}, {"id":"","label":"Mask blur","localized":"","hint":"How much to blur the mask before processing, in pixels"}, {"id":"","label":"Mask transparency","localized":"","hint":""}, {"id":"","label":"Inpaint masked","localized":"","hint":""}, @@ -484,8 +485,8 @@ {"id":"","label":"Dark","localized":"","hint":""}, {"id":"","label":"Light","localized":"","hint":""}, {"id":"","label":"Show grid in results","localized":"","hint":""}, - {"id":"","label":"For inpainting, include the greyscale mask in results","localized":"","hint":""}, - {"id":"","label":"For inpainting, include masked composite in results","localized":"","hint":""}, + {"id":"","label":"Inpainting include greyscale mask in results","localized":"","hint":""}, + {"id":"","label":"Inpainting include masked composite in results","localized":"","hint":""}, {"id":"","label":"Do not change selected model when reading generation parameters","localized":"","hint":""}, {"id":"","label":"Send seed when sending prompt or image to other interface","localized":"","hint":""}, {"id":"","label":"Send size when sending prompt or image to another interface","localized":"","hint":""}, @@ -500,7 +501,7 @@ {"id":"","label":"Show previews of all images generated in a batch as a grid","localized":"","hint":""}, {"id":"","label":"Play a sound when images are finished generating","localized":"","hint":""}, {"id":"","label":"Path to notification sound","localized":"","hint":""}, - {"id":"","label":"Live preview display period","localized":"","hint":""}, + {"id":"","label":"Live preview display period","localized":"","hint":"Request preview image every n steps, set to 0 to disable"}, {"id":"","label":"Full VAE","localized":"","hint":""}, {"id":"","label":"Approximate","localized":"","hint":"Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality"}, {"id":"","label":"Simple","localized":"","hint":"Very cheap approximation. 
Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality"}, diff --git a/html/orchid-dreams.jpg b/html/orchid-dreams.jpg new file mode 100644 index 000000000..8a62eeb0a Binary files /dev/null and b/html/orchid-dreams.jpg differ diff --git a/html/reference.json b/html/reference.json index a3ff809ed..1ab1291f7 100644 --- a/html/reference.json +++ b/html/reference.json @@ -1,13 +1,46 @@ { + "DreamShaper SD 1.5 v8": { + "path": "dreamshaper_8.safetensors@https://civitai.com/api/download/models/128713", + "desc": "Showcase finetuned model based on Stable diffusion 1.5", + "preview": "dreamshaper_8.jpg", + "original": true + }, + "DreamShaper SD XL Turbo": { + "path": "dreamshaperXL_turboDpmppSDE.safetensors@https://civitai.com/api/download/models/251662", + "desc": "Showcase finetuned model based on Stable diffusion XL", + "preview": "dreamshaperXL_turboDpmppSDE.jpg" + }, + "Juggernaut Reborn": { + "path": "juggernaut_reborn.safetensors@https://civitai.com/api/download/models/274039", + "desc": "Showcase finetuned model based on Stable diffusion 1.5", + "preview": "juggernaut_reborn.jpg", + "original": true + }, + "Juggernaut XL v7 RunDiffusion": { + "path": "juggernautXL_v7Rundiffusion.safetensors@https://civitai.com/api/download/models/240840", + "desc": "Showcase finetuned model based on Stable diffusion XL", + "preview": "juggernautXL_v7Rundiffusion.jpg" + }, "RunwayML SD 1.5": { "path": "runwayml/stable-diffusion-v1-5", + "alt": "v1-5-pruned-emaonly.safetensors@https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors?download=true", "desc": "Stable Diffusion 1.5 is the base model all other 1.5 checkpoint were trained from. It's a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. The Stable-Diffusion-v1-5 checkpoint was initialized with the weights of the Stable-Diffusion-v1-2 checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512.", - "preview": "runwayml--stable-diffusion-v1-5.jpg" + "preview": "runwayml--stable-diffusion-v1-5.jpg", + "original": true + }, + "StabilityAI SD 2.1 EMA": { + "path": "stabilityai/stable-diffusion-2-1-base", + "alt": "v2-1_512-ema-pruned.safetensors@https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors?download=true", + "desc": "This stable-diffusion-2-1-base model fine-tunes stable-diffusion-2-base (512-base-ema.ckpt) with 220k extra steps taken", + "preview": "stabilityai--stable-diffusion-2.1-base.jpg", + "original": true }, - "StabilityAI SD 2.1": { + "StabilityAI SD 2.1 V": { "path": "stabilityai/stable-diffusion-2-1-base", - "desc": "This stable-diffusion-2-1 model is fine-tuned from stable-diffusion-2 (768-v-ema.ckpt) with an additional 55k steps on the same dataset. Improvement over base 1.5 model, but never really took off.", - "preview": "stabilityai--stable-diffusion-2.1-base.jpg" + "alt": "v2-1_768-ema-pruned.safetensors@https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors?download=true", + "desc": "This stable-diffusion-2 model is resumed from stable-diffusion-2-base (512-base-ema.ckpt) and trained for 150k steps using a v-objective on the same dataset. 
Resumed for another 140k steps on 768x768 images", + "preview": "stabilityai--stable-diffusion-2.1-base.jpg", + "original": true }, "StabilityAI SD-XL 1.0 Base": { "path": "stabilityai/stable-diffusion-xl-base-1.0", @@ -16,8 +49,10 @@ }, "StabilityAI SD 2.1 Turbo": { "path": "stabilityai/sd-turbo", + "alt": "sd_turbo.safetensors@https://huggingface.co/stabilityai/sd-turbo/resolve/main/sd_turbo.safetensors?download=true", "desc": "SD-Turbo is a distilled version of Stable Diffusion 2.1, trained for real-time synthesis. SD-Turbo is based on a novel training method called Adversarial Diffusion Distillation (ADD) (see the technical report), which allows sampling large-scale foundational image diffusion models in 1 to 4 steps at high image quality. This approach uses score distillation to leverage large-scale off-the-shelf image diffusion models as a teacher signal and combines this with an adversarial loss to ensure high image fidelity even in the low-step regime of one or two sampling steps.", - "preview": "stabilityai--sd-turbo.jpg" + "preview": "stabilityai--sd-turbo.jpg", + "original": true }, "StabilityAI SD-XL Turbo": { "path": "stabilityai/sdxl-turbo", @@ -34,6 +69,11 @@ "desc": "(SVD) Image-to-Video is a latent diffusion model trained to generate short video clips from an image conditioning. This model was trained to generate 25 frames at resolution 576x1024 given a context frame of the same size, finetuned from SVD Image-to-Video [14 frames]. We also finetune the widely used f8-decoder for temporal consistency.", "preview": "stabilityai--stable-video-diffusion-img2vid-xt.jpg" }, + "Segmind Vega": { + "path": "segmind/Segmind-Vega", + "desc": "The Segmind-Vega Model is a distilled version of the Stable Diffusion XL (SDXL), offering a remarkable 70% reduction in size and an impressive 100% speedup while retaining high-quality text-to-image generation capabilities. Trained on diverse datasets, including Grit and Midjourney scrape data, it excels at creating a wide range of visual content based on textual prompts. Employing a knowledge distillation strategy, Segmind-Vega leverages the teachings of several expert models, including SDXL, ZavyChromaXL, and JuggernautXL, to combine their strengths and produce compelling visual outputs.", + "preview": "segmind--Segmind-Vega.jpg" + }, "Segmind SSD-1B": { "path": "segmind/SSD-1B", "desc": "The Segmind Stable Diffusion Model (SSD-1B) offers a compact, efficient, and distilled version of the SDXL model. At 50% smaller and 60% faster than Stable Diffusion XL (SDXL), it provides quick and seamless performance without sacrificing image quality.", @@ -79,14 +119,49 @@ "desc": "Kandinsky 3.0 is an open-source text-to-image diffusion model built upon the Kandinsky2-x model family. In comparison to its predecessors, Kandinsky 3.0 incorporates more data and specifically related to Russian culture, which allows to generate pictures related to Russin culture. 
Furthermore, enhancements have been made to the text understanding and visual quality of the model, achieved by increasing the size of the text encoder and Diffusion U-Net models, respectively.", "preview": "kandinsky-community--kandinsky-3.jpg" }, + "Playground v1": { + "path": "playgroundai/playground-v1", + "desc": "Playground v1 is a latent diffusion model that improves the overall HDR quality to get more stunning images.", + "preview": "playgroundai--playground-v1.jpg" + }, + "Playground v2 256": { + "path": "playgroundai/playground-v2-256px-base", + "desc": "Playground v2 is a diffusion-based text-to-image generative model. The model was trained from scratch by the research team at Playground. Images generated by Playground v2 are favored 2.5 times more than those produced by Stable Diffusion XL, according to Playground’s user study.", + "preview": "playgroundai--playground-v2-256px-base.jpg" + }, + "Playground v2 512": { + "path": "playgroundai/playground-v2-512px-base", + "desc": "Playground v2 is a diffusion-based text-to-image generative model. The model was trained from scratch by the research team at Playground. Images generated by Playground v2 are favored 2.5 times more than those produced by Stable Diffusion XL, according to Playground’s user study.", + "preview": "playgroundai--playground-v2-512px-base.jpg" + }, + "Playground v2 1024": { + "path": "playgroundai/playground-v2-1024px-aesthetic", + "desc": "Playground v2 is a diffusion-based text-to-image generative model. The model was trained from scratch by the research team at Playground. Images generated by Playground v2 are favored 2.5 times more than those produced by Stable Diffusion XL, according to Playground’s user study.", + "preview": "playgroundai--playground-v2-1024px-aesthetic.jpg" + }, "DeepFloyd IF Medium": { "path": "DeepFloyd/IF-I-M-v1.0", "desc": "DeepFloyd-IF is a pixel-based text-to-image triple-cascaded diffusion model, that can generate pictures with new state-of-the-art for photorealism and language understanding. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID-30K score of 6.66 on the COCO dataset. It is modular and composed of frozen text mode and three pixel cascaded diffusion modules, each designed to generate images of increasing resolution: 64x64, 256x256, and 1024x1024.", "preview": "DeepFloyd--IF-I-M-v1.0.jpg" }, + "aMUSEd 256": { + "path": "amused/amused-256", + "desc": "Amused is a lightweight text to image model based off of the muse architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.", + "preview": "amused--amused-256.jpg" + }, + "aMUSEd 512": { + "path": "amused/amused-512", + "desc": "Amused is a lightweight text to image model based off of the muse architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.", + "preview": "amused--amused-512.jpg" + }, "Tsinghua UniDiffuser": { "path": "thu-ml/unidiffuser-v1", "desc": "UniDiffuser is a unified diffusion framework to fit all distributions relevant to a set of multi-modal data in one transformer. UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead.\nSpecifically, UniDiffuser employs a variation of transformer, called U-ViT, which parameterizes the joint noise prediction network. 
Other components perform as encoders and decoders of different modalities, including a pretrained image autoencoder from Stable Diffusion, a pretrained image ViT-B/32 CLIP encoder, a pretrained text ViT-L CLIP encoder, and a GPT-2 text decoder finetuned by ourselves.", "preview": "thu-ml--unidiffuser-v1.jpg" + }, + "SalesForce BLIP-Diffusion": { + "path": "salesforce/blipdiffusion", + "desc": "BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation.", + "preview": "salesforce--blipdiffusion.jpg" } } \ No newline at end of file diff --git a/html/timeless-beige.jpg b/html/timeless-beige.jpg new file mode 100644 index 000000000..e8c591379 Binary files /dev/null and b/html/timeless-beige.jpg differ diff --git a/html/xmas-control.jpg b/html/xmas-control.jpg new file mode 100644 index 000000000..ce86b0ee7 Binary files /dev/null and b/html/xmas-control.jpg differ diff --git a/html/xmas-default.jpg b/html/xmas-default.jpg new file mode 100644 index 000000000..33633d886 Binary files /dev/null and b/html/xmas-default.jpg differ diff --git a/installer.py b/installer.py index 413463fbb..5972f66c0 100644 --- a/installer.py +++ b/installer.py @@ -67,6 +67,7 @@ def emit(self, record): def get(self): return self.buffer + from functools import partial, partialmethod from logging.handlers import RotatingFileHandler from rich.theme import Theme from rich.logging import RichHandler @@ -78,6 +79,11 @@ def get(self): global log_file # pylint: disable=global-statement log_file = args.log + logging.TRACE = 25 + logging.addLevelName(logging.TRACE, 'TRACE') + logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE) + logging.trace = partial(logging.log, logging.TRACE) + level = logging.DEBUG if args.debug else logging.INFO log.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd` console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ @@ -357,7 +363,8 @@ def check_torch(): log.debug(f'Torch allowed: cuda={allow_cuda} rocm={allow_rocm} ipex={allow_ipex} diml={allow_directml} openvino={allow_openvino}') torch_command = os.environ.get('TORCH_COMMAND', '') xformers_package = os.environ.get('XFORMERS_PACKAGE', 'none') - install('onnxruntime', 'onnxruntime', ignore=True) + if not installed('onnxruntime', quiet=True) and not installed('onnxruntime-gpu', quiet=True): # allow either + install('onnxruntime', 'onnxruntime', ignore=True) if torch_command != '': pass elif allow_cuda and (shutil.which('nvidia-smi') is not None or args.use_xformers or os.path.exists(os.path.join(os.environ.get('SystemRoot') or r'C:\Windows', 'System32', 'nvidia-smi.exe'))): @@ -424,18 +431,32 @@ def check_torch(): os.environ.setdefault('NEOReadDebugKeys', '1') os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100') if "linux" in sys.platform: - torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/') - os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow==2.13.0 intel-extension-for-tensorflow[gpu]') + torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.1.0a0 torchvision==0.16.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url 
https://pytorch-extension.intel.com/release-whl/stable/xpu/us/') + os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow==2.14.0 intel-extension-for-tensorflow[xpu]==2.14.0.1') + install(os.environ.get('MKL_PACKAGE', 'mkl==2024.0.0'), 'mkl') + install(os.environ.get('DPCPP_PACKAGE', 'mkl-dpcpp==2024.0.0'), 'mkl-dpcpp') else: - pytorch_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl' - torchvision_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl' - ipex_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl' + if sys.version_info[1] == 11: + pytorch_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.1.10%2Bxpu/torch-2.1.0a0+cxx11.abi-cp311-cp311-win_amd64.whl' + torchvision_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.1.10%2Bxpu/torchvision-0.16.0a0+cxx11.abi-cp311-cp311-win_amd64.whl' + ipex_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.1.10%2Bxpu/intel_extension_for_pytorch-2.1.10+xpu-cp311-cp311-win_amd64.whl' + elif sys.version_info[1] == 10: + pytorch_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.1.10%2Bxpu/torch-2.1.0a0+cxx11.abi-cp310-cp310-win_amd64.whl' + torchvision_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.1.10%2Bxpu/torchvision-0.16.0a0+cxx11.abi-cp310-cp310-win_amd64.whl' + ipex_pip = 'https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.1.10%2Bxpu/intel_extension_for_pytorch-2.1.10+xpu-cp310-cp310-win_amd64.whl' + else: + pytorch_pip = 'torch==2.1.0a0' + torchvision_pip = 'torchvision==0.16.0a0' + ipex_pip = 'intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/' + install(os.environ.get('MKL_PACKAGE', 'mkl==2024.0.0'), 'mkl') + install(os.environ.get('DPCPP_PACKAGE', 'mkl-dpcpp==2024.0.0'), 'mkl-dpcpp') torch_command = os.environ.get('TORCH_COMMAND', f'{pytorch_pip} {torchvision_pip} {ipex_pip}') - install('openvino', 'openvino', ignore=True) + install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2023.2.0'), 'openvino', ignore=True) + install('nncf==2.7.0', 'nncf', ignore=True) install('onnxruntime-openvino', 'onnxruntime-openvino', ignore=True) elif allow_openvino and args.use_openvino: log.info('Using OpenVINO') - torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cpu') + torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu') else: machine = platform.machine() if sys.platform == 'darwin': @@ -503,11 +524,10 @@ def check_torch(): if opts.get('cuda_compile_backend', '') == 'hidet': install('hidet', 'hidet') if args.use_openvino or opts.get('cuda_compile_backend', '') == 'openvino_fx': - uninstall('openvino-nightly') # TODO openvino: remove after people had enough time upgrading - install('openvino==2023.2.0', 'openvino') + install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2023.2.0'), 'openvino') + install('nncf==2.7.0', 'nncf') install('onnxruntime-openvino', 'onnxruntime-openvino', 
ignore=True) # TODO openvino: numpy version conflicts with tensorflow and doesn't support Python 3.11 os.environ.setdefault('PYTORCH_TRACING_MODE', 'TORCHFX') - os.environ.setdefault('SD_LORA_DIFFUSERS', '1') os.environ.setdefault('NEOReadDebugKeys', '1') os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100') if args.profile: @@ -556,36 +576,6 @@ def install_packages(): print_profile(pr, 'Packages') -# clone required repositories -def install_repositories(): - """ - if args.profile: - pr = cProfile.Profile() - pr.enable() - def d(name): - return os.path.join(os.path.dirname(__file__), 'repositories', name) - log.info('Verifying repositories') - os.makedirs(os.path.join(os.path.dirname(__file__), 'repositories'), exist_ok=True) - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") - stable_diffusion_commit = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', None) - clone(stable_diffusion_repo, d('stable-diffusion-stability-ai'), stable_diffusion_commit) - taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") - taming_transformers_commit = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', None) - clone(taming_transformers_repo, d('taming-transformers'), taming_transformers_commit) - k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') - k_diffusion_commit = os.environ.get('K_DIFFUSION_COMMIT_HASH', '0455157') - clone(k_diffusion_repo, d('k-diffusion'), k_diffusion_commit) - codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') - codeformer_commit = os.environ.get('CODEFORMER_COMMIT_HASH', "7a584fd") - clone(codeformer_repo, d('CodeFormer'), codeformer_commit) - blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') - blip_commit = os.environ.get('BLIP_COMMIT_HASH', None) - clone(blip_repo, d('BLIP'), blip_commit) - if args.profile: - print_profile(pr, 'Repositories') - """ - - # run extension installer def run_extension_installer(folder): path_installer = os.path.realpath(os.path.join(folder, "install.py")) @@ -614,7 +604,7 @@ def list_extensions_folder(folder, quiet=False): if disabled_extensions_all != 'none': return [] disabled_extensions = opts.get('disabled_extensions', []) - enabled_extensions = [x for x in os.listdir(folder) if x not in disabled_extensions and not x.startswith('.')] + enabled_extensions = [x for x in os.listdir(folder) if os.path.isdir(os.path.join(folder, x)) and x not in disabled_extensions and not x.startswith('.')] if not quiet: log.info(f'Extensions: enabled={enabled_extensions} {name}') return enabled_extensions @@ -758,6 +748,7 @@ def set_environment(): os.environ.setdefault('USE_TORCH', '1') os.environ.setdefault('UVICORN_TIMEOUT_KEEP_ALIVE', '60') os.environ.setdefault('KINETO_LOG_LEVEL', '3') + os.environ.setdefault('DO_NOT_TRACK', '1') os.environ.setdefault('HF_HUB_CACHE', opts.get('hfcache_dir', os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub'))) log.debug(f'Cache folder: {os.environ.get("HF_HUB_CACHE")}') if sys.platform == 'darwin': @@ -919,7 +910,7 @@ def add_args(parser): group.add_argument('--reset', default = os.environ.get("SD_RESET",False), action='store_true', help = "Reset main repository to latest version, default: %(default)s") group.add_argument('--upgrade', default = os.environ.get("SD_UPGRADE",False), action='store_true', help = "Upgrade main repository to 
latest version, default: %(default)s") group.add_argument('--requirements', default = os.environ.get("SD_REQUIREMENTS",False), action='store_true', help = "Force re-check of requirements, default: %(default)s") - group.add_argument('--quick', default = os.environ.get("SD_QUICK",False), action='store_true', help = "Run with startup sequence only, default: %(default)s") + group.add_argument('--quick', default = os.environ.get("SD_QUICK",False), action='store_true', help = "Bypass version checks, default: %(default)s") group.add_argument('--use-directml', default = os.environ.get("SD_USEDIRECTML",False), action='store_true', help = "Use DirectML if no compatible GPU is detected, default: %(default)s") group.add_argument("--use-openvino", default = os.environ.get("SD_USEOPENVINO",False), action='store_true', help="Use Intel OpenVINO backend, default: %(default)s") group.add_argument("--use-ipex", default = os.environ.get("SD_USEIPEX",False), action='store_true', help="Force use Intel OneAPI XPU backend, default: %(default)s") diff --git a/javascript/amethyst-nightfall.css b/javascript/amethyst-nightfall.css index b36ae8160..ea544dcc4 100644 --- a/javascript/amethyst-nightfall.css +++ b/javascript/amethyst-nightfall.css @@ -101,9 +101,6 @@ svg.feather.feather-image, .feather .feather-image { display: none } #txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } #txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } #txt2img_gallery, #img2img_gallery, #extras_gallery { padding: 0; margin: 0; object-fit: contain; box-shadow: none; min-height: 0; } -#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } -#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } -#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } #extras_upscale { margin-top: 10px } #txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } diff --git a/javascript/black-orange.css b/javascript/black-orange.css index b35267daa..90e1cd8bd 100644 --- a/javascript/black-orange.css +++ b/javascript/black-orange.css @@ -117,9 +117,6 @@ svg.feather.feather-image, .feather .feather-image { display: none } #txt2img_cfg_scale { min-width: 200px; } #txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } #txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } -#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } -#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } -#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } #extras_upscale { margin-top: 10px } #txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } diff --git a/javascript/black-teal.css b/javascript/black-teal.css index b2cd0b984..07624a76e 100644 --- a/javascript/black-teal.css +++ b/javascript/black-teal.css @@ -51,7 +51,7 @@ input[type=range]::-moz-range-thumb { box-shadow: 2px 2px 3px #111111 !important ::-webkit-scrollbar { width: 12px; } ::-webkit-scrollbar-track { background: #333333; } ::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; } -div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; } +div.form { border-width: 0; box-shadow: none; background: transparent; overflow: 
visible; } /* gradio style classes */ fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; } @@ -78,7 +78,7 @@ svg.feather.feather-image, .feather .feather-image { display: none } .px-4 { padding-lefT: 1rem; padding-right: 1rem; } .py-6 { padding-bottom: 0; } .tabs { background-color: var(--background-color); } -.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; } +.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.7rem; } .tab-nav { zoom: 120%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; } .label-wrap { margin: 8px 0px 4px 0px; } .gradio-button.tool { border: none; background: none; box-shadow: none; filter: hue-rotate(340deg) saturate(0.5); } @@ -118,11 +118,7 @@ svg.feather.feather-image, .feather .feather-image { display: none } #txt2img_cfg_scale { min-width: 200px; } #txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } #txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } -#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } -#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } -#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } textarea[rows="1"] { height: 33px !important; width: 99% !important; padding: 8px !important; } - #extras_upscale { margin-top: 10px } #txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } #txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; } diff --git a/javascript/control.js b/javascript/control.js new file mode 100644 index 000000000..3690c491a --- /dev/null +++ b/javascript/control.js @@ -0,0 +1,17 @@ +function setupControlUI() { + const tabs = ['input', 'output', 'preview']; + for (const tab of tabs) { + const btn = gradioApp().getElementById(`control-${tab}-button`); + if (!btn) continue; // eslint-disable-line no-continue + btn.style.cursor = 'pointer'; + btn.onclick = () => { + const t = gradioApp().getElementById(`control-tab-${tab}`); + t.style.display = t.style.display === 'none' ? 'block' : 'none'; + const c = gradioApp().getElementById(`control-${tab}-column`); + c.style.flexGrow = c.style.flexGrow === '0' ? 
'9' : '0'; + }; + } + log('initControlUI'); +} + +onUiLoaded(setupControlUI); diff --git a/javascript/dragDrop.js b/javascript/dragDrop.js index bc7777102..73c0c33c3 100644 --- a/javascript/dragDrop.js +++ b/javascript/dragDrop.js @@ -48,9 +48,11 @@ window.document.addEventListener('dragover', (e) => { const target = e.composedPath()[0]; const imgWrap = target.closest('[data-testid="image"]'); if (!imgWrap && target.placeholder && target.placeholder.indexOf('Prompt') === -1) return; - e.stopPropagation(); - e.preventDefault(); - e.dataTransfer.dropEffect = 'copy'; + if ((e.dataTransfer?.files?.length || 0) > 0) { + e.stopPropagation(); + e.preventDefault(); + e.dataTransfer.dropEffect = 'copy'; + } }); window.document.addEventListener('drop', (e) => { @@ -59,10 +61,11 @@ window.document.addEventListener('drop', (e) => { if (target.placeholder.indexOf('Prompt') === -1) return; const imgWrap = target.closest('[data-testid="image"]'); if (!imgWrap) return; - e.stopPropagation(); - e.preventDefault(); - const { files } = e.dataTransfer; - dropReplaceImage(imgWrap, files); + if ((e.dataTransfer?.files?.length || 0) > 0) { + e.stopPropagation(); + e.preventDefault(); + dropReplaceImage(imgWrap, e.dataTransfer.files); + } }); window.addEventListener('paste', (e) => { diff --git a/javascript/emerald-paradise.css b/javascript/emerald-paradise.css new file mode 100644 index 000000000..9694a153d --- /dev/null +++ b/javascript/emerald-paradise.css @@ -0,0 +1,297 @@ +/* generic html tags */ +:root, .light, .dark { + --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif; + --font-mono: 'ui-monospace', 'Consolas', monospace; + --font-size: 16px; + --primary-100: #1e2223; /* bg color*/ + --primary-200: #242a2c; /* drop down menu/ prompt window fill*/ + --primary-300: #0a0c0e; /* black */ + --primary-400: #2a302c; /* small buttons*/ + --primary-500: #4b695d; /* main accent color green*/ + --primary-700: #273538; /* extension box fill*/ + --primary-800: #d15e84; /* pink(hover accent)*/ + --highlight-color: var(--primary-500); + --inactive-color: var(--primary--800); + --body-text-color: var(--neutral-100); + --body-text-color-subdued: var(--neutral-300); + --background-color: var(--primary-100); + --background-fill-primary: var(--input-background-fill); + --input-padding: 8px; + --input-background-fill: var(--primary-200); + --input-shadow: none; + --button-secondary-text-color: white; + --button-secondary-background-fill: var(--primary-400); + --button-secondary-background-fill-hover: var(--primary-700); + --block-title-text-color: var(--neutral-300); + --radius-sm: 1px; + --radius-lg: 6px; + --spacing-md: 4px; + --spacing-xxl: 8px; + --line-sm: 1.2em; + --line-md: 1.4em; +} + +html { font-size: var(--font-size); } +body, button, input, select, textarea { font-family: var(--font);} +button { font-size: 1.2rem; max-width: 400px; } +img { background-color: var(--background-color); } +input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; } +input[type=range]::-webkit-slider-runnable-track, input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; } +input[type=range]::-webkit-slider-thumb, input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; 
appearance: none; margin-top: 0px; } +input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); } +::-webkit-scrollbar-track { background: #333333; } +::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; } +div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; } +div.compact { gap: 1em; } + +/* gradio style classes */ +fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; } +.border-2 { border-width: 0; } +.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; } +.bg-white { color: lightyellow; background-color: var(--inactive-color); } +.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px } +.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; } +.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; } +.gr-check-radio:checked { background-color: var(--highlight-color); } +.gr-compact { background-color: var(--background-color); } +.gr-form { border-width: 0; } +.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; } +.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; } +.gr-panel { background-color: var(--background-color); } +.eta-bar { display: none !important } +svg.feather.feather-image, .feather .feather-image { display: none } +.gap-2 { padding-top: 8px; } +.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; } +.output-html { line-height: 1.2rem; overflow-x: hidden; } +.output-html > div { margin-bottom: 8px; } +.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */ +.p-2 { padding: 0; } +.px-4 { padding-lefT: 1rem; padding-right: 1rem; } +.py-6 { padding-bottom: 0; } +.tabs { background-color: var(--background-color); } +.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; } +.tab-nav { zoom: 120%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; } +div.tab-nav button.selected {background-color: var(--button-primary-background-fill);} +#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;} +.label-wrap { background-color: #191919; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; } +.small-accordion .label-wrap { padding: 8px 0px 8px 0px; } +.small-accordion .label-wrap .icon { margin-right: 1em; } +.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);} +button.selected {background: var(--button-primary-background-fill);} +.center.boundedheight.flex {background-color: var(--input-background-fill);} +.compact {border-radius: var(--border-radius-lg);} +#logMonitorData 
{background-color: var(--input-background-fill);} +#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); } +#tab_extensions table, #tab_config table { width: 96vw; } +#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;} +#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);} + +/* automatic style classes */ +.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); } +.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; } +.gallery-item { box-shadow: none !important; } +.performance { color: #888; } +.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; } +.image-buttons { gap: 10px !important; justify-content: center; } +.image-buttons > button { max-width: 160px; } +.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) } +#system_row > button, #settings_row > button, #config_row > button { max-width: 190px; } + +/* gradio elements overrides */ +#div.gradio-container { overflow-x: hidden; } +#img2img_label_copy_to_img2img { font-weight: normal; } +#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; } +#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; } +#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; } +#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; } +#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; } +#quicksettings button {padding: 0 0.5em 0.1em 0.5em;} +#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; } +#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; } +#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; } +#settings > div.flex-wrap { width: 15em; } +#txt2img_cfg_scale { min-width: 200px; } +#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } +#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } +#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } +#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } +#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } + +#extras_upscale { margin-top: 10px } +#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } +#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; } +#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) } +#txt2img_tools, #img2img_tools { margin-top: -4px; 
margin-bottom: -4px; } +#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; } + +/* based on gradio built-in dark theme */ +:root, .light, .dark { + --body-background-fill: var(--background-color); + --color-accent-soft: var(--neutral-700); + --background-fill-secondary: none; + --border-color-accent: var(--background-color); + --border-color-primary: var(--background-color); + --link-text-color-active: var(--primary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --shadow-spread: 1px; + --block-background-fill: None; + --block-border-color: var(--border-color-primary); + --block_border_width: None; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: None; + --block-label-text-color: var(--neutral-200); + --block_shadow: None; + --block_title_background_fill: None; + --block_title_border_color: None; + --block_title_border_width: None; + --panel-background-fill: var(--background-fill-secondary); + --panel-border-color: var(--border-color-primary); + --panel_border_width: None; + --checkbox-background-color: var(--primary-200); + --checkbox-background-color-focus: var(--primary-700); + --checkbox-background-color-hover: var(--primary-700); + --checkbox-background-color-selected: var(--primary-500); + --checkbox-border-color: transparent; + --checkbox-border-color-focus: var(--primary-800); + --checkbox-border-color-hover: var(--primary-800); + --checkbox-border-color-selected: var(--primary-800); + --checkbox-border-width: var(--input-border-width); + --checkbox-label-background-fill: None; + --checkbox-label-background-fill-hover: None; + --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); + --checkbox-label-border-color: var(--border-color-primary); + --checkbox-label-border-color-hover: var(--checkbox-label-border-color); + --checkbox-label-border-width: var(--input-border-width); + --checkbox-label-text-color: var(--body-text-color); + --checkbox-label-text-color-selected: var(--checkbox-label-text-color); + --error-background-fill: var(--background-fill-primary); + --error-border-color: var(--border-color-primary); + --error-text-color: #f768b7; /*was ef4444*/ + --input-background-fill-focus: var(--secondary-600); + --input-background-fill-hover: var(--input-background-fill); + --input-border-color: var(--background-color); + --input-border-color-focus: var(--primary-800); + --input-placeholder-color: var(--neutral-500); + --input-shadow-focus: None; + --loader_color: None; + --slider_color: None; + --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800)); + --table-border-color: var(--neutral-700); + --table-even-background-fill: var(--primary-300); + --table-odd-background-fill: var(--primary-200); + --table-row-focus: var(--color-accent-soft); + --button-border-width: var(--input-border-width); + --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); + --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); + --button-cancel-border-color: #dc2626; + --button-cancel-border-color-hover: var(--button-cancel-border-color); + --button-cancel-text-color: white; + --button-cancel-text-color-hover: var(--button-cancel-text-color); + --button-primary-background-fill: 
var(--primary-500); + --button-primary-background-fill-hover: var(--primary-800); + --button-primary-border-color: var(--primary-500); + --button-primary-border-color-hover: var(--button-primary-border-color); + --button-primary-text-color: white; + --button-primary-text-color-hover: var(--button-primary-text-color); + --button-secondary-border-color: var(--neutral-600); + --button-secondary-border-color-hover: var(--button-secondary-border-color); + --button-secondary-text-color-hover: var(--button-secondary-text-color); + --secondary-50: #eff6ff; + --secondary-100: #dbeafe; + --secondary-200: #bfdbfe; + --secondary-300: #93c5fd; + --secondary-400: #60a5fa; + --secondary-500: #3b82f6; + --secondary-600: #2563eb; + --secondary-700: #1d4ed8; + --secondary-800: #1e40af; + --secondary-900: #1e3a8a; + --secondary-950: #1d3660; + --neutral-50: #f0f0f0; /* */ + --neutral-100: #e8e8e3;/* majority of text (neutral gray yellow) */ + --neutral-200: #d0d0d0; + --neutral-300: #b3b5ac; /* top tab /sub text (light accent) */ + --neutral-400: #ffba85;/* tab title (bright orange) */ + --neutral-500: #48665b; /* prompt text (desat accent)*/ + --neutral-600: #373f39; /* tab outline color (accent color)*/ + --neutral-700: #2b373b; /* small settings tab accent */ + --neutral-800: #f379c2; /* bright pink accent */ + --neutral-900: #111827; + --neutral-950: #0b0f19; + --radius-xxs: 0; + --radius-xs: 0; + --radius-md: 0; + --radius-xl: 0; + --radius-xxl: 0; + --body-text-size: var(--text-md); + --body-text-weight: 400; + --embed-radius: var(--radius-lg); + --color-accent: var(--primary-500); + --shadow-drop: 0; + --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1); + --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset; + --block-border-width: 1px; + --block-info-text-size: var(--text-sm); + --block-info-text-weight: 400; + --block-label-border-width: 1px; + --block-label-margin: 0; + --block-label-padding: var(--spacing-sm) var(--spacing-lg); + --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0; + --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px); + --block-label-text-size: var(--text-sm); + --block-label-text-weight: 400; + --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px); + --block-radius: var(--radius-lg); + --block-shadow: var(--shadow-drop); + --block-title-background-fill: none; + --block-title-border-color: none; + --block-title-border-width: 0; + --block-title-padding: 0; + --block-title-radius: none; + --block-title-text-size: var(--text-md); + --block-title-text-weight: 400; + --container-radius: var(--radius-lg); + --form-gap-width: 1px; + --layout-gap: var(--spacing-xxl); + --panel-border-width: 0; + --section-header-text-size: var(--text-md); + --section-header-text-weight: 400; + --checkbox-border-radius: var(--radius-sm); + --checkbox-label-gap: 2px; + --checkbox-label-padding: var(--spacing-md); + --checkbox-label-shadow: var(--shadow-drop); + --checkbox-label-text-size: var(--text-md); + --checkbox-label-text-weight: 400; + --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e"); + --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e"); + --checkbox-shadow: 
var(--input-shadow); + --error-border-width: 1px; + --input-border-width: 1px; + --input-radius: var(--radius-lg); + --input-text-size: var(--text-md); + --input-text-weight: 400; + --loader-color: var(--color-accent); + --prose-text-size: var(--text-md); + --prose-text-weight: 400; + --prose-header-text-weight: 600; + --slider-color: ; + --table-radius: var(--radius-lg); + --button-large-padding: 2px 6px; + --button-large-radius: var(--radius-lg); + --button-large-text-size: var(--text-lg); + --button-large-text-weight: 400; + --button-shadow: none; + --button-shadow-active: none; + --button-shadow-hover: none; + --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm)); + --button-small-radius: var(--radius-lg); + --button-small-text-size: var(--text-md); + --button-small-text-weight: 400; + --button-transition: none; + --size-9: 64px; + --size-14: 64px; +} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 216673ede..18e401fde 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -16,7 +16,12 @@ const requestGet = (url, data, handler) => { xhr.send(JSON.stringify(data)); }; -const getENActiveTab = () => gradioApp().getElementById('tab_txt2img').style.display === 'block' ? 'txt2img' : 'img2img'; +const getENActiveTab = () => { + if (gradioApp().getElementById('tab_txt2img').style.display === 'block') return 'txt2img'; + if (gradioApp().getElementById('tab_img2img').style.display === 'block') return 'img2img'; + if (gradioApp().getElementById('tab_control').style.display === 'block') return 'control'; + return ''; +}; const getENActivePage = () => { const tabname = getENActiveTab(); @@ -90,14 +95,23 @@ function readCardDescription(page, item) { }); } -async function filterExtraNetworksForTab(tabname, searchTerm) { +function getCardsForActivePage() { + const pagename = getENActivePage(); + if (!pagename) return []; + const allCards = Array.from(gradioApp().querySelectorAll('.extra-network-cards > .card')); + const cards = allCards.filter((el) => el.dataset.page.toLowerCase().includes(pagename.toLowerCase())); + log('getCardsForActivePage', pagename, cards.length); + return allCards; +} + +async function filterExtraNetworksForTab(searchTerm) { let found = 0; let items = 0; const t0 = performance.now(); const pagename = getENActivePage(); if (!pagename) return; const allPages = Array.from(gradioApp().querySelectorAll('.extra-network-cards')); - const pages = allPages.filter((el) => el.id.includes(pagename.toLowerCase())); + const pages = allPages.filter((el) => el.id.toLowerCase().includes(pagename.toLowerCase())); for (const pg of pages) { const cards = Array.from(pg.querySelectorAll('.card') || []); cards.forEach((elem) => { @@ -158,7 +172,7 @@ function sortExtraNetworks() { const pagename = getENActivePage(); if (!pagename) return 'sort error: unknown page'; const allPages = Array.from(gradioApp().querySelectorAll('.extra-network-cards')); - const pages = allPages.filter((el) => el.id.includes(pagename.toLowerCase())); + const pages = allPages.filter((el) => el.id.toLowerCase().includes(pagename.toLowerCase())); let num = 0; for (const pg of pages) { const cards = Array.from(pg.querySelectorAll('.card') || []); @@ -255,12 +269,23 @@ function refeshDetailsEN(args) { return args; } -// init +// refresh on en show +function refreshENpage() { + if (getCardsForActivePage().length === 0) { + log('refreshENpage'); + const tabname = getENActiveTab(); + const btnRefresh = gradioApp().getElementById(`${tabname}_extra_refresh`); + if 
(btnRefresh) btnRefresh.click(); + } +} +// init function setupExtraNetworksForTab(tabname) { - gradioApp().querySelector(`#${tabname}_extra_tabs`).classList.add('extra-networks'); + let tabs = gradioApp().querySelector(`#${tabname}_extra_tabs`); + if (tabs) tabs.classList.add('extra-networks'); const en = gradioApp().getElementById(`${tabname}_extra_networks`); - const tabs = gradioApp().querySelector(`#${tabname}_extra_tabs > div`); + tabs = gradioApp().querySelector(`#${tabname}_extra_tabs > div`); + if (!tabs) return; // buttons const btnRefresh = gradioApp().getElementById(`${tabname}_extra_refresh`); @@ -307,7 +332,7 @@ function setupExtraNetworksForTab(tabname) { txtSearchValue.addEventListener('input', (evt) => { if (searchTimer) clearTimeout(searchTimer); searchTimer = setTimeout(() => { - filterExtraNetworksForTab(tabname, txtSearchValue.value.toLowerCase()); + filterExtraNetworksForTab(txtSearchValue.value.toLowerCase()); searchTimer = null; }, 150); }); @@ -332,13 +357,14 @@ function setupExtraNetworksForTab(tabname) { }; // en style + if (!en) return; const intersectionObserver = new IntersectionObserver((entries) => { - if (!en) return; for (const el of Array.from(gradioApp().querySelectorAll('.extra-networks-page'))) { el.style.height = `${window.opts.extra_networks_height}vh`; el.parentElement.style.width = '-webkit-fill-available'; } if (entries[0].intersectionRatio > 0) { + refreshENpage(); if (window.opts.extra_networks_card_cover === 'cover') { en.style.transition = ''; en.style.zIndex = 100; @@ -375,9 +401,11 @@ function setupExtraNetworksForTab(tabname) { function setupExtraNetworks() { setupExtraNetworksForTab('txt2img'); setupExtraNetworksForTab('img2img'); + setupExtraNetworksForTab('control'); function registerPrompt(tabname, id) { const textarea = gradioApp().querySelector(`#${id} > label > textarea`); + if (!textarea) return; if (!activePromptTextarea[tabname]) activePromptTextarea[tabname] = textarea; textarea.addEventListener('focus', () => { activePromptTextarea[tabname] = textarea; }); } @@ -386,6 +414,8 @@ function setupExtraNetworks() { registerPrompt('txt2img', 'txt2img_neg_prompt'); registerPrompt('img2img', 'img2img_prompt'); registerPrompt('img2img', 'img2img_neg_prompt'); + registerPrompt('control', 'control_prompt'); + registerPrompt('control', 'control_neg_prompt'); log('initExtraNetworks'); } diff --git a/javascript/imageParams.js b/javascript/imageParams.js index c93523988..0b739cd00 100644 --- a/javascript/imageParams.js +++ b/javascript/imageParams.js @@ -9,16 +9,15 @@ async function initDragDrop() { if (!target.placeholder) return; if (target.placeholder.indexOf('Prompt') === -1) return; const promptTarget = get_tab_index('tabs') === 1 ? 
'img2img_prompt_image' : 'txt2img_prompt_image'; - e.stopPropagation(); - e.preventDefault(); const imgParent = gradioApp().getElementById(promptTarget); - if (!imgParent) return; - const { files } = e.dataTransfer; const fileInput = imgParent.querySelector('input[type="file"]'); - if (fileInput) { - fileInput.files = files; + if (!imgParent || !fileInput) return; + if ((e.dataTransfer?.files?.length || 0) > 0) { + e.stopPropagation(); + e.preventDefault(); + fileInput.files = e.dataTransfer.files; fileInput.dispatchEvent(new Event('change')); - log('dropEvent'); + log('dropEvent files', fileInput.files); } }); } diff --git a/javascript/invoked.css b/javascript/invoked.css index ab5b48c06..d93f04ccf 100644 --- a/javascript/invoked.css +++ b/javascript/invoked.css @@ -113,9 +113,6 @@ button.selected {background: var(--button-primary-background-fill);} #txt2img_cfg_scale { min-width: 200px; } #txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } #txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } -#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } -#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } -#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } #extras_upscale { margin-top: 10px } #txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } diff --git a/javascript/light-teal.css b/javascript/light-teal.css index ae582231f..4829791f9 100644 --- a/javascript/light-teal.css +++ b/javascript/light-teal.css @@ -79,7 +79,7 @@ svg.feather.feather-image, .feather .feather-image { display: none } .px-4 { padding-lefT: 1rem; padding-right: 1rem; } .py-6 { padding-bottom: 0; } .tabs { background-color: var(--background-color); } -.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; } +.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.7rem; } .tab-nav { zoom: 120%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; } .label-wrap { margin: 16px 0px 8px 0px; } .gradio-button.tool { border: none; background: none; box-shadow: none; filter: hue-rotate(340deg) saturate(0.5); } @@ -115,9 +115,6 @@ svg.feather.feather-image, .feather .feather-image { display: none } #txt2img_cfg_scale { min-width: 200px; } #txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } #txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } -#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } -#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } -#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } #extras_upscale { margin-top: 10px } #txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } diff --git a/javascript/logMonitor.js b/javascript/logMonitor.js index 7677cfced..af744aee1 100644 --- a/javascript/logMonitor.js +++ b/javascript/logMonitor.js @@ -9,8 +9,15 @@ async function logMonitor() { try { res = await fetch('/sdapi/v1/log?clear=True'); } catch {} if (res?.ok) { logMonitorStatus = true; - if (!logMonitorEl) logMonitorEl = document.getElementById('logMonitorData'); + if (!logMonitorEl) { + 
logMonitorEl = document.getElementById('logMonitorData'); + logMonitorEl.onscrollend = () => { + const at_bottom = logMonitorEl.scrollHeight <= (logMonitorEl.scrollTop + logMonitorEl.clientHeight); + if (at_bottom) logMonitorEl.parentElement.style = ''; + }; + } if (!logMonitorEl) return; + const at_bottom = logMonitorEl.scrollHeight <= (logMonitorEl.scrollTop + logMonitorEl.clientHeight); const lines = await res.json(); if (logMonitorEl && lines?.length > 0) logMonitorEl.parentElement.parentElement.style.display = opts.logmonitor_show ? 'block' : 'none'; for (const line of lines) { @@ -23,7 +30,8 @@ async function logMonitor() { } catch {} } while (logMonitorEl.childElementCount > 100) logMonitorEl.removeChild(logMonitorEl.firstChild); - logMonitorEl.scrollTop = logMonitorEl.scrollHeight; + if (at_bottom) logMonitorEl.scrollTop = logMonitorEl.scrollHeight; + else if (lines?.length > 0) logMonitorEl.parentElement.style = 'border-bottom: 2px solid var(--highlight-color);'; } } diff --git a/javascript/midnight-barbie.css b/javascript/midnight-barbie.css index 8f47fe262..d988a1275 100644 --- a/javascript/midnight-barbie.css +++ b/javascript/midnight-barbie.css @@ -106,10 +106,6 @@ svg.feather.feather-image, .feather .feather-image { display: none } #txt2img_cfg_scale { min-width: 200px; } #txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } #txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } -#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } -#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } -#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } - #extras_upscale { margin-top: 10px } #txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } #txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; } diff --git a/javascript/orchid-dreams.css b/javascript/orchid-dreams.css new file mode 100644 index 000000000..8eab006cb --- /dev/null +++ b/javascript/orchid-dreams.css @@ -0,0 +1,297 @@ +/* generic html tags */ +:root, .light, .dark { + --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif; + --font-mono: 'ui-monospace', 'Consolas', monospace; + --font-size: 16px; + --primary-100: #2a2a34; /* bg color*/ + --primary-200: #1f2028; /* drop down menu/ prompt*/ + --primary-300: #0a0c0e; /* black */ + --primary-400: #40435c; /* small buttons*/ + --primary-500: #4c48b5; /* main accent color purple*/ + --primary-700: #1f2028; /* darker hover accent*/ + --primary-800: #e95ee3; /* pink accent*/ + --highlight-color: var(--primary-500); + --inactive-color: var(--primary--800); + --body-text-color: var(--neutral-100); + --body-text-color-subdued: var(--neutral-300); + --background-color: var(--primary-100); + --background-fill-primary: var(--input-background-fill); + --input-padding: 8px; + --input-background-fill: var(--primary-200); + --input-shadow: none; + --button-secondary-text-color: white; + --button-secondary-background-fill: var(--primary-400); + --button-secondary-background-fill-hover: var(--primary-700); + --block-title-text-color: var(--neutral-300); + --radius-sm: 1px; + --radius-lg: 6px; + --spacing-md: 4px; + --spacing-xxl: 8px; + --line-sm: 1.2em; + --line-md: 1.4em; +} + +html { font-size: var(--font-size); } +body, button, input, select, textarea { font-family: var(--font);} +button { font-size: 1.2rem; max-width: 
400px; } +img { background-color: var(--background-color); } +input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; } +input[type=range]::-webkit-slider-runnable-track, input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; } +input[type=range]::-webkit-slider-thumb, input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; } +input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); } +::-webkit-scrollbar-track { background: #333333; } +::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; } +div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; } +div.compact { gap: 1em; } + +/* gradio style classes */ +fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; } +.border-2 { border-width: 0; } +.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; } +.bg-white { color: lightyellow; background-color: var(--inactive-color); } +.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px } +.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; } +.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; } +.gr-check-radio:checked { background-color: var(--highlight-color); } +.gr-compact { background-color: var(--background-color); } +.gr-form { border-width: 0; } +.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; } +.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; } +.gr-panel { background-color: var(--background-color); } +.eta-bar { display: none !important } +svg.feather.feather-image, .feather .feather-image { display: none } +.gap-2 { padding-top: 8px; } +.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; } +.output-html { line-height: 1.2rem; overflow-x: hidden; } +.output-html > div { margin-bottom: 8px; } +.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */ +.p-2 { padding: 0; } +.px-4 { padding-lefT: 1rem; padding-right: 1rem; } +.py-6 { padding-bottom: 0; } +.tabs { background-color: var(--background-color); } +.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; } +.tab-nav { zoom: 120%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; } +div.tab-nav button.selected {background-color: var(--button-primary-background-fill);} +#settings div.tab-nav button.selected 
{background-color: var(--background-color); color: var(--primary-800); font-weight: bold;} +.label-wrap { background-color: #18181e; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; } +.small-accordion .label-wrap { padding: 8px 0px 8px 0px; } +.small-accordion .label-wrap .icon { margin-right: 1em; } +.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);} +button.selected {background: var(--button-primary-background-fill);} +.center.boundedheight.flex {background-color: var(--input-background-fill);} +.compact {border-radius: var(--border-radius-lg);} +#logMonitorData {background-color: var(--input-background-fill);} +#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); } +#tab_extensions table, #tab_config table { width: 96vw; } +#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;} +#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);} + +/* automatic style classes */ +.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); } +.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; } +.gallery-item { box-shadow: none !important; } +.performance { color: #888; } +.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; } +.image-buttons { gap: 10px !important; justify-content: center; } +.image-buttons > button { max-width: 160px; } +.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) } +#system_row > button, #settings_row > button, #config_row > button { max-width: 190px; } + +/* gradio elements overrides */ +#div.gradio-container { overflow-x: hidden; } +#img2img_label_copy_to_img2img { font-weight: normal; } +#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; } +#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; } +#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; } +#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; } +#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; } +#quicksettings button {padding: 0 0.5em 0.1em 0.5em;} +#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; } +#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; } +#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; } +#settings > div.flex-wrap { width: 15em; } +#txt2img_cfg_scale { min-width: 200px; } +#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } +#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } +#txt2img_actions_column, 
#img2img_actions_column { flex-flow: wrap; justify-content: space-between; } +#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } +#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } + +#extras_upscale { margin-top: 10px } +#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } +#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; } +#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) } +#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; } +#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; } + +/* based on gradio built-in dark theme */ +:root, .light, .dark { + --body-background-fill: var(--background-color); + --color-accent-soft: var(--neutral-700); + --background-fill-secondary: none; + --border-color-accent: var(--background-color); + --border-color-primary: var(--background-color); + --link-text-color-active: var(--primary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --shadow-spread: 1px; + --block-background-fill: None; + --block-border-color: var(--border-color-primary); + --block_border_width: None; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: None; + --block-label-text-color: var(--neutral-200); + --block_shadow: None; + --block_title_background_fill: None; + --block_title_border_color: None; + --block_title_border_width: None; + --panel-background-fill: var(--background-fill-secondary); + --panel-border-color: var(--border-color-primary); + --panel_border_width: None; + --checkbox-background-color: var(--primary-200); + --checkbox-background-color-focus: var(--primary-400); + --checkbox-background-color-hover: var(--primary-200); + --checkbox-background-color-selected: var(--primary-400); + --checkbox-border-color: transparent; + --checkbox-border-color-focus: var(--primary-800); + --checkbox-border-color-hover: var(--primary-800); + --checkbox-border-color-selected: var(--primary-800); + --checkbox-border-width: var(--input-border-width); + --checkbox-label-background-fill: None; + --checkbox-label-background-fill-hover: None; + --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); + --checkbox-label-border-color: var(--border-color-primary); + --checkbox-label-border-color-hover: var(--checkbox-label-border-color); + --checkbox-label-border-width: var(--input-border-width); + --checkbox-label-text-color: var(--body-text-color); + --checkbox-label-text-color-selected: var(--checkbox-label-text-color); + --error-background-fill: var(--background-fill-primary); + --error-border-color: var(--border-color-primary); + --error-text-color: #f768b7; /*was ef4444*/ + --input-background-fill-focus: var(--secondary-600); + --input-background-fill-hover: var(--input-background-fill); + --input-border-color: var(--background-color); + --input-border-color-focus: var(--primary-800); + --input-placeholder-color: var(--neutral-500); + --input-shadow-focus: None; + --loader_color: None; + --slider_color: None; + --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800)); + 
--table-border-color: var(--neutral-700); + --table-even-background-fill: var(--primary-300); + --table-odd-background-fill: var(--primary-200); + --table-row-focus: var(--color-accent-soft); + --button-border-width: var(--input-border-width); + --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); + --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); + --button-cancel-border-color: #dc2626; + --button-cancel-border-color-hover: var(--button-cancel-border-color); + --button-cancel-text-color: white; + --button-cancel-text-color-hover: var(--button-cancel-text-color); + --button-primary-background-fill: var(--primary-500); + --button-primary-background-fill-hover: var(--primary-800); + --button-primary-border-color: var(--primary-500); + --button-primary-border-color-hover: var(--button-primary-border-color); + --button-primary-text-color: white; + --button-primary-text-color-hover: var(--button-primary-text-color); + --button-secondary-border-color: var(--neutral-600); + --button-secondary-border-color-hover: var(--button-secondary-border-color); + --button-secondary-text-color-hover: var(--button-secondary-text-color); + --secondary-50: #eff6ff; + --secondary-100: #dbeafe; + --secondary-200: #bfdbfe; + --secondary-300: #93c5fd; + --secondary-400: #60a5fa; + --secondary-500: #3b82f6; + --secondary-600: #2563eb; + --secondary-700: #1d4ed8; + --secondary-800: #1e40af; + --secondary-900: #1e3a8a; + --secondary-950: #1d3660; + --neutral-50: #f0f0f0; /* */ + --neutral-100: #ddd5e8;/* majority of text (neutral gray purple) */ + --neutral-200: #d0d0d0; + --neutral-300: #bfbad6; /* top tab text (light accent) */ + --neutral-400: #ffba85;/* tab title (bright orange) */ + --neutral-500: #545b94; /* prompt text (desat accent)*/ + --neutral-600: #1f2028; /* tab outline color (accent color)*/ + --neutral-700: #20212c; /* unchanged settings tab accent (dark)*/ + --neutral-800: #e055dc; /* bright pink accent */ + --neutral-900: #111827; + --neutral-950: #0b0f19; + --radius-xxs: 0; + --radius-xs: 0; + --radius-md: 0; + --radius-xl: 0; + --radius-xxl: 0; + --body-text-size: var(--text-md); + --body-text-weight: 400; + --embed-radius: var(--radius-lg); + --color-accent: var(--primary-500); + --shadow-drop: 0; + --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1); + --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset; + --block-border-width: 1px; + --block-info-text-size: var(--text-sm); + --block-info-text-weight: 400; + --block-label-border-width: 1px; + --block-label-margin: 0; + --block-label-padding: var(--spacing-sm) var(--spacing-lg); + --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0; + --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px); + --block-label-text-size: var(--text-sm); + --block-label-text-weight: 400; + --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px); + --block-radius: var(--radius-lg); + --block-shadow: var(--shadow-drop); + --block-title-background-fill: none; + --block-title-border-color: none; + --block-title-border-width: 0; + --block-title-padding: 0; + --block-title-radius: none; + --block-title-text-size: var(--text-md); + --block-title-text-weight: 400; + --container-radius: var(--radius-lg); + --form-gap-width: 1px; + --layout-gap: var(--spacing-xxl); + --panel-border-width: 0; + --section-header-text-size: var(--text-md); + --section-header-text-weight: 400; + --checkbox-border-radius: 
var(--radius-sm); + --checkbox-label-gap: 2px; + --checkbox-label-padding: var(--spacing-md); + --checkbox-label-shadow: var(--shadow-drop); + --checkbox-label-text-size: var(--text-md); + --checkbox-label-text-weight: 400; + --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e"); + --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e"); + --checkbox-shadow: var(--input-shadow); + --error-border-width: 1px; + --input-border-width: 1px; + --input-radius: var(--radius-lg); + --input-text-size: var(--text-md); + --input-text-weight: 400; + --loader-color: var(--color-accent); + --prose-text-size: var(--text-md); + --prose-text-weight: 400; + --prose-header-text-weight: 600; + --slider-color: ; + --table-radius: var(--radius-lg); + --button-large-padding: 2px 6px; + --button-large-radius: var(--radius-lg); + --button-large-text-size: var(--text-lg); + --button-large-text-weight: 400; + --button-shadow: none; + --button-shadow-active: none; + --button-shadow-hover: none; + --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm)); + --button-small-radius: var(--radius-lg); + --button-small-text-size: var(--text-md); + --button-small-text-weight: 400; + --button-transition: none; + --size-9: 64px; + --size-14: 64px; +} diff --git a/javascript/progressBar.js b/javascript/progressBar.js index 0133d3c1f..85c339f26 100644 --- a/javascript/progressBar.js +++ b/javascript/progressBar.js @@ -40,7 +40,7 @@ function checkPaused(state) { } function setProgress(res) { - const elements = ['txt2img_generate', 'img2img_generate', 'extras_generate']; + const elements = ['txt2img_generate', 'img2img_generate', 'extras_generate', 'control_generate']; const progress = (res?.progress || 0); const job = res?.job || ''; const perc = res && (progress > 0) ? `${Math.round(100.0 * progress)}%` : ''; @@ -57,10 +57,12 @@ function setProgress(res) { document.title = `SD.Next ${perc}`; for (const elId of elements) { const el = document.getElementById(elId); - el.innerText = (res ? `${job} ${perc} ${eta}` : 'Generate'); - el.style.background = res && (progress > 0) - ? `linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${perc}, var(--neutral-700) ${perc})` - : 'var(--button-primary-background-fill)'; + if (el) { + el.innerText = (res ? `${job} ${perc} ${eta}` : 'Generate'); + el.style.background = res && (progress > 0) + ? 
`linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${perc}, var(--neutral-700) ${perc})` + : 'var(--button-primary-background-fill)'; + } } } diff --git a/javascript/promptChecker.js b/javascript/promptChecker.js index d18a030ec..284482205 100644 --- a/javascript/promptChecker.js +++ b/javascript/promptChecker.js @@ -36,4 +36,6 @@ onAfterUiUpdate(() => { setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); setupBracketChecking('img2img_prompt', 'img2img_token_counter'); setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); + setupBracketChecking('control_prompt', 'control_token_counter'); + setupBracketChecking('control_neg_prompt', 'control_negative_token_counter'); }); diff --git a/javascript/sdnext.css b/javascript/sdnext.css index 769d9da16..8ebaf10c5 100644 --- a/javascript/sdnext.css +++ b/javascript/sdnext.css @@ -48,11 +48,11 @@ textarea { overflow-y: auto !important; } /* custom gradio elements */ .accordion-compact { padding: 8px 0px 4px 0px !important; } -.settings-accordion > div { flex-flow: wrap; } +.settings-accordion>div { flex-flow: wrap; } .small-accordion .form { min-width: var(--left-column) !important; max-width: max-content; } .small-accordion .label-wrap .icon { margin-right: 1.6em; margin-left: 0.6em; color: var(--button-primary-border-color); } .small-accordion .label-wrap { padding: 16px 0px 8px 0px; margin: 0; border-top: 2px solid var(--button-secondary-border-color); } -.small-accordion { width: fit-content !important; padding-left: 0 !important; } +.small-accordion { width: fit-content !important; min-width: fit-content !important; padding-left: 0 !important; } .extension-script { max-width: 48vw; } button.custom-button{ border-radius: var(--button-large-radius); padding: var(--button-large-padding); font-weight: var(--button-large-text-weight); border: var(--button-border-width) solid var(--button-secondary-border-color); background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); font-size: var(--button-large-text-size); @@ -62,7 +62,7 @@ button.custom-button{ border-radius: var(--button-large-radius); padding: var(-- .theme-preview { display: none; position: fixed; border: var(--spacing-sm) solid var(--neutral-600); box-shadow: 2px 2px 2px 2px var(--neutral-700); top: 0; bottom: 0; left: 0; right: 0; margin: auto; max-width: 75vw; z-index: 999; } /* txt2img/img2img specific */ -.block.token-counter{ position: absolute; display: inline-block; right: 0; min-width: 0 !important; width: auto; z-index: 100; top: -0.75em; } +.block.token-counter{ position: absolute; display: inline-block; right: 1em; min-width: 0 !important; width: auto; z-index: 100; top: -0.5em; } .block.token-counter span{ background: var(--input-background-fill) !important; box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075); border: 2px solid rgba(192,192,192,0.4) !important; } .block.token-counter.error span{ box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075); border: 2px solid rgba(255,0,0,0.4) !important; } .block.token-counter div{ display: inline; } @@ -70,18 +70,24 @@ button.custom-button{ border-radius: var(--button-large-radius); padding: var(-- .performance { font-size: 0.85em; color: #444; } .performance p { display: inline-block; color: var(--body-text-color-subdued) !important } .performance .time { margin-right: 0; } -#txt2img_prompt_container, #img2img_prompt_container { margin-right: var(--layout-gap) } 
-#txt2img_footer, #img2img_footer, #extras_footer { height: fit-content; } -#txt2img_footer, #img2img_footer { height: fit-content; display: none; } -#txt2img_generate_box, #img2img_generate_box { gap: 0.5em; flex-wrap: wrap-reverse; height: fit-content; } -#txt2img_actions_column, #img2img_actions_column { gap: 0.5em; height: fit-content; } -#txt2img_generate_box > button, #img2img_generate_box > button, #txt2img_enqueue, #img2img_enqueue { min-height: 42px; max-height: 42px; line-height: 1em; } -#txt2img_generate_line2, #img2img_generate_line2, #txt2img_tools, #img2img_tools { display: flex; } -#txt2img_generate_line2 > button, #img2img_generate_line2 > button, #extras_generate_box > button, #txt2img_tools > button, #img2img_tools > button { height: 2em; line-height: 0; font-size: var(--input-text-size); +.thumbnails { background: var(--body-background-fill); } +#control_gallery { height: 564px; } +#control-result { padding: 0.5em; } +#control-inputs { margin-top: 1em; } +#txt2img_prompt_container, #img2img_prompt_container, #control_prompt_container { margin-right: var(--layout-gap) } +#txt2img_footer, #img2img_footer, #extras_footer, #control_footer { height: fit-content; display: none; } +#txt2img_generate_box, #img2img_generate_box, #control_general_box { gap: 0.5em; flex-wrap: wrap-reverse; height: fit-content; } +#txt2img_actions_column, #img2img_actions_column, #control_actions_column { gap: 0.3em; height: fit-content; } +#txt2img_generate_box>button, #img2img_generate_box>button, #control_generate_box>button, #txt2img_enqueue, #img2img_enqueue { min-height: 42px; max-height: 42px; line-height: 1em; } +#txt2img_generate_line2, #img2img_generate_line2, #txt2img_tools, #img2img_tools, #control_generate_line2, #control_tools { display: flex; } +#txt2img_generate_line2>button, #img2img_generate_line2>button, #extras_generate_box>button, #control_generate_line2>button, #txt2img_tools>button, #img2img_tools>button, #control_tools>button { height: 2em; line-height: 0; font-size: var(--input-text-size); min-width: unset; display: block !important; } -#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { display: contents; } +#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt, #control_prompt, #control_neg_prompt { display: contents; } +#txt2img_generate_box, #img2img_generate_box, #control_generate_box { min-width: unset; width: 48%; } +#txt2img_actions_column, #img2img_actions_column, #control_actions { flex-flow: wrap; justify-content: space-between; } +#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper, #control_enqueue_wrapper { min-width: unset !important; width: 48%; } .interrogate-col{ min-width: 0 !important; max-width: fit-content; margin-right: var(--spacing-xxl); } -.interrogate-col > button{ flex: 1; } +.interrogate-col>button{ flex: 1; } #sampler_selection_img2img { margin-top: 1em; } #txtimg_hr_finalres{ min-height: 0 !important; } #img2img_scale_resolution_preview.block{ display: flex; align-items: end; } @@ -89,13 +95,12 @@ button.custom-button{ border-radius: var(--button-large-radius); padding: var(-- div#extras_scale_to_tab div.form{ flex-direction: row; } #img2img_unused_scale_by_slider { visibility: hidden; width: 0.5em; max-width: 0.5em; min-width: 0.5em; } .inactive{ opacity: 0.5; } -div.dimensions-tools { min-width: 0 !important; max-width: fit-content; flex-direction: row; align-content: center; } div#extras_scale_to_tab div.form{ flex-direction: row; } -#mode_img2img .gradio-image > div.fixed-height, #mode_img2img 
.gradio-image > div.fixed-height img{ height: 480px !important; max-height: 480px !important; min-height: 480px !important; } +#mode_img2img .gradio-image>div.fixed-height, #mode_img2img .gradio-image>div.fixed-height img{ height: 480px !important; max-height: 480px !important; min-height: 480px !important; } #img2img_sketch, #img2maskimg, #inpaint_sketch { overflow: overlay !important; resize: auto; background: var(--panel-background-fill); z-index: 5; } .image-buttons button{ min-width: auto; } .infotext { overflow-wrap: break-word; line-height: 1.5em; } -.infotext > p { padding-left: 1em; text-indent: -1em; white-space: pre-wrap; } +.infotext>p { padding-left: 1em; text-indent: -1em; white-space: pre-wrap; } .tooltip { display: block; position: fixed; top: 1em; right: 1em; padding: 0.5em; background: var(--input-background-fill); color: var(--body-text-color); border: 1pt solid var(--button-primary-border-color); width: 22em; min-height: 1.3em; font-size: 0.8em; transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; } .tooltip-show { opacity: 0.9; } @@ -104,15 +109,15 @@ div#extras_scale_to_tab div.form{ flex-direction: row; } /* settings */ #si-sparkline-memo, #si-sparkline-load { background-color: #111; } #quicksettings { width: fit-content; } -#quicksettings > button { padding: 0 1em 0 0; align-self: end; margin-bottom: var(--text-sm); } +#quicksettings>button { padding: 0 1em 0 0; align-self: end; margin-bottom: var(--text-sm); } #settings { display: flex; gap: var(--layout-gap); } #settings div { border: none; gap: 0; margin: 0 0 var(--layout-gap) 0px; padding: 0; } -#settings > div.tab-content { flex: 10 0 75%; display: grid; } -#settings > div.tab-content > div { border: none; padding: 0; } -#settings > div.tab-content > div > div > div > div > div { flex-direction: unset; } -#settings > div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: var(--spacing-xxl); } -#settings > div.tab-nav button { display: block; border: none; text-align: left; white-space: initial; padding: 0; } -#settings > div.tab-nav > #settings_show_all_pages { padding: var(--size-2) var(--size-4); } +#settings>div.tab-content { flex: 10 0 75%; display: grid; } +#settings>div.tab-content>div { border: none; padding: 0; } +#settings>div.tab-content>div>div>div>div>div { flex-direction: unset; } +#settings>div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: var(--spacing-xxl); } +#settings>div.tab-nav button { display: block; border: none; text-align: left; white-space: initial; padding: 0; } +#settings>div.tab-nav>#settings_show_all_pages { padding: var(--size-2) var(--size-4); } #settings .block.gradio-checkbox { margin: 0; width: auto; } #settings .dirtyable { gap: .5em; } #settings .dirtyable.hidden { display: none; } @@ -146,8 +151,8 @@ div#extras_scale_to_tab div.form{ flex-direction: row; } .modalControls span:hover, .modalControls span:focus { color: var(--highlight-color); filter: none; } .lightboxModalPreviewZone { display: flex; width: 100%; height: 100%; } .lightboxModalPreviewZone:focus-visible { outline: none; } -.lightboxModalPreviewZone > img { display: block; margin: auto; width: auto; } -.lightboxModalPreviewZone > img.modalImageFullscreen{ object-fit: contain; height: 100%; width: 100%; min-height: 0; background: transparent; } +.lightboxModalPreviewZone>img { display: 
block; margin: auto; width: auto; } +.lightboxModalPreviewZone>img.modalImageFullscreen{ object-fit: contain; height: 100%; width: 100%; min-height: 0; background: transparent; } table.settings-value-table { background: white; border-collapse: collapse; margin: 1em; border: var(--spacing-sm) solid white; } table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-width: 36em; } .modalPrev, .modalNext { cursor: pointer; position: relative; z-index: 1; top: 0; width: auto; height: 100vh; line-height: 100vh; text-align: center; padding: 16px; @@ -178,13 +183,13 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt #extensions .date{ opacity: 0.85; font-size: 90%; } /* extra networks */ -.extra-networks > div { margin: 0; border-bottom: none !important; gap: 0.3em 0; } +.extra-networks>div { margin: 0; border-bottom: none !important; gap: 0.3em 0; } .extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); } .extra-networks .search { flex: 1; } .extra-networks .description { flex: 3; } -.extra-networks .tab-nav > button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; } +.extra-networks .tab-nav>button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; } .extra-networks .buttons { position: absolute; right: 0; margin: -4px; background: var(--background-color); } -.extra-networks .buttons > button { margin-left: -0.2em; height: 1.4em; color: var(--primary-300) !important; } +.extra-networks .buttons>button { margin-left: -0.2em; height: 1.4em; color: var(--primary-300) !important; } .extra-networks .custom-button { width: 120px; width: 100%; background: none; justify-content: left; text-align: left; padding: 3px 3px 3px 12px; text-indent: -6px; box-shadow: none; line-break: auto; } .extra-networks .custom-button:hover { background: var(--button-primary-background-fill) } .extra-networks-tab { padding: 0 !important; } @@ -200,18 +205,18 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt .extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); } .extra-network-cards .card .overlay .tags { display: none; overflow-wrap: break-word; } .extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-md); cursor: pointer; display: inline-block; } -.extra-network-cards .card .actions > span { padding: 4px; } -.extra-network-cards .card .actions > span:hover { color: var(--highlight-color); } +.extra-network-cards .card .actions>span { padding: 4px; } +.extra-network-cards .card .actions>span:hover { color: var(--highlight-color); } .extra-network-cards .card:hover .actions { display: block; } .extra-network-cards .card:hover .overlay .tags { display: block; } .extra-network-cards .card .actions { font-size: 3em; display: none; text-align-last: right; cursor: pointer; font-variant: unicase; position: absolute; z-index: 100; right: 0; height: 0.7em; width: 100%; background: rgba(0, 0, 0, 0.40); } .extra-network-cards .card-list { display: flex; margin: 0.3em; padding: 0.3em; background: var(--input-background-fill); cursor: pointer; border-radius: var(--button-large-radius); } .extra-network-cards .card-list .tag { color: var(--primary-500); margin-left: 0.8em; } .extra-details-close { position: fixed; top: 0.2em; right: 0.2em; z-index: 99; background: var(--button-secondary-background-fill) !important; } -#txt2img_description, 
#img2img_description { max-height: 63px; overflow-y: auto !important; } -#txt2img_description > label > textarea, #img2img_description > label > textarea { font-size: 0.9em } +#txt2img_description, #img2img_description, #control_description { max-height: 63px; overflow-y: auto !important; } +#txt2img_description>label>textarea, #img2img_description>label>textarea, #control_description>label>textarea { font-size: 0.9em } -#txt2img_extra_details > div, #img2img_extra_details > div { overflow-y: auto; min-height: 40vh; max-height: 80vh; align-self: flex-start; } +#txt2img_extra_details>div, #img2img_extra_details>div { overflow-y: auto; min-height: 40vh; max-height: 80vh; align-self: flex-start; } #txt2img_extra_details, #img2img_extra_details { position: fixed; bottom: 50%; left: 50%; transform: translate(-50%, 50%); padding: 0.8em; border: var(--block-border-width) solid var(--highlight-color) !important; z-index: 100; box-shadow: var(--button-shadow); } #txt2img_extra_details td:first-child, #img2img_extra_details td:first-child { font-weight: bold; vertical-align: top; } @@ -221,14 +226,14 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt /* specific elements */ #modelmerger_interp_description { margin-top: 1em; margin-bottom: 1em; } #scripts_alwayson_txt2img, #scripts_alwayson_img2img { padding: 0 } -#scripts_alwayson_txt2img > .label-wrap, #scripts_alwayson_img2img > .label-wrap { background: var(--input-background-fill); padding: 0; margin: 0; border-radius: var(--radius-lg); } -#scripts_alwayson_txt2img > .label-wrap > span, #scripts_alwayson_img2img > .label-wrap > span { padding: var(--spacing-xxl); } +#scripts_alwayson_txt2img>.label-wrap, #scripts_alwayson_img2img>.label-wrap { background: var(--input-background-fill); padding: 0; margin: 0; border-radius: var(--radius-lg); } +#scripts_alwayson_txt2img>.label-wrap>span, #scripts_alwayson_img2img>.label-wrap>span { padding: var(--spacing-xxl); } #scripts_alwayson_txt2img div { max-width: var(--left-column); } #script_txt2img_agent_scheduler { display: none; } #refresh_tac_refreshTempFiles { display: none; } #train_tab { flex-flow: row-reverse; } #models_tab { flex-flow: row-reverse; } -#swap_axes > button { min-width: 100px; font-size: 1em; } +#swap_axes>button { min-width: 100px; font-size: 1em; } #ui_defaults_review { margin: 1em; } /* extras */ @@ -251,6 +256,17 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt .nvml { position: fixed; bottom: 10px; right: 10px; background: var(--background-fill-primary); border: 1px solid var(--button-primary-border-color); padding: 6px; color: var(--button-primary-text-color); font-size: 0.7em; z-index: 50; font-family: monospace; display: none; } +/* control */ +#control_input_type { max-width: 18em } +#control_settings .small-accordion .form { min-width: 350px !important } +.control-button { min-height: 42px; max-height: 42px; line-height: 1em; } +.control-tabs>.tab-nav { margin-bottom: 0; margin-top: 0; } +.processor-settings { padding: 0 !important; max-width: 300px; } +.processor-group>div { flex-flow: wrap;gap: 1em; } + +/* main info */ +.main-info { font-weight: var(--section-header-text-weight); color: var(--body-text-color-subdued); padding: 1em !important; margin-top: 2em !important; line-height: var(--line-lg) !important; } + /* loader */ .splash { position: fixed; top: 0; left: 0; width: 100vw; height: 100vh; z-index: 1000; display: block; text-align: center; } .motd { margin-top: 2em; color: 
var(--body-text-color-subdued); font-family: monospace; font-variant: all-petite-caps; } @@ -280,76 +296,40 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt --spacing-xxl: 6px; } -/* Apply different styles for devices with coarse pointers dependant on screen resolution */ -@media (hover: none) and (pointer: coarse) { - - /* Do not affect displays larger than 1024px wide. */ - @media (max-width: 1024px) { - - /* Screens smaller than 400px wide */ - @media (max-width: 399px) { +@media (hover: none) and (pointer: coarse) { /* Apply different styles for devices with coarse pointers dependant on screen resolution */ + @media (max-width: 1024px) { /* Do not affect displays larger than 1024px wide. */ + @media (max-width: 399px) { /* Screens smaller than 400px wide */ :root, .light, .dark { --left-column: 100%; } - - /* maintain single column for from image operations on larger mobile devices */ - #txt2img_results, #img2img_results, #extras_results { min-width: calc(min(320px, 100%)) !important;} + #txt2img_results, #img2img_results, #extras_results { min-width: calc(min(320px, 100%)) !important;} /* maintain single column for from image operations on larger mobile devices */ #txt2img_footer p { text-wrap: wrap; } - - } - - /* Screens larger than 400px wide */ - @media (min-width: 400px) { + } + @media (min-width: 400px) { /* Screens larger than 400px wide */ :root, .light, .dark {--left-column: 50%;} - - /* maintain side by side split on larger mobile displays for from text */ - #txt2img_results, #extras_results, #txt2img_footer p {text-wrap: wrap; max-width: 100% !important; } + #txt2img_results, #extras_results, #txt2img_footer p {text-wrap: wrap; max-width: 100% !important; } /* maintain side by side split on larger mobile displays for from text */ } - #scripts_alwayson_txt2img div, #scripts_alwayson_img2img div { max-width: 100%; } - #txt2img_prompt_container, #img2img_prompt_container { resize:vertical !important; } - - /* make generate and enqueue buttons take up the entire width of their rows. */ - #txt2img_generate_box, #txt2img_enqueue_wrapper { min-width: 100% !important;} - - /*make interrogate buttons take up appropriate space. */ - #img2img_toprow > div.gradio-column {flex-grow: 1 !important;} + #scripts_alwayson_txt2img div, #scripts_alwayson_img2img div { max-width: 100%; } + #txt2img_prompt_container, #img2img_prompt_container, #control_prompt_container { resize:vertical !important; } + #txt2img_generate_box, #txt2img_enqueue_wrapper { min-width: 100% !important;} /* make generate and enqueue buttons take up the entire width of their rows. */ + #img2img_toprow>div.gradio-column {flex-grow: 1 !important;} /*make interrogate buttons take up appropriate space. 
*/ #img2img_actions_column {display: flex; min-width: fit-content !important; flex-direction: row;justify-content: space-evenly; align-items: center;} #txt2img_generate_box, #img2img_generate_box, #txt2img_enqueue_wrapper,#img2img_enqueue_wrapper {display: flex;flex-direction: column;height: 4em !important;align-items: stretch;justify-content: space-evenly;} - - /* maintain single column for from image operations on larger mobile devices */ - #img2img_interface, #img2img_results, #img2img_footer p {text-wrap: wrap; min-width: 100% !important; max-width: 100% !important;} - /* fix inpaint image display being too large for mobile displays */ - #img2img_sketch, #img2maskimg, #inpaint_sketch {display: flex; alignment-baseline:after-edge !important; overflow: auto !important; resize: none !important; } + #img2img_interface, #img2img_results, #img2img_footer p {text-wrap: wrap; min-width: 100% !important; max-width: 100% !important;} /* maintain single column for from image operations on larger mobile devices */ + #img2img_sketch, #img2maskimg, #inpaint_sketch {display: flex; overflow: auto !important; resize: none !important; } /* fix inpaint image display being too large for mobile displays */ #img2maskimg canvas { width: auto !important; max-height: 100% !important; height: auto !important; } - - /* fix from text/image UI elements to prevent them from moving around within the UI */ - #txt2img_sampler, #txt2img_batch, #txt2img_seed_group, #txt2img_advanced, #txt2img_second_pass, #img2img_sampling_group, #img2img_resize_group, #img2img_batch_group, #img2img_seed_group, #img2img_denoise_group, #img2img_advanced_group { width: 100% !important; } - #img2img_resize_group .gradio-radio > div { display: flex; flex-direction: column; width: unset !important; } + #txt2img_sampler, #txt2img_batch, #txt2img_seed_group, #txt2img_advanced, #txt2img_second_pass, #img2img_sampling_group, #img2img_resize_group, #img2img_batch_group, #img2img_seed_group, #img2img_denoise_group, #img2img_advanced_group { width: 100% !important; } /* fix from text/image UI elements to prevent them from moving around within the UI */ + #img2img_resize_group .gradio-radio>div { display: flex; flex-direction: column; width: unset !important; } #inpaint_controls div {display:flex;flex-direction: row;} - #inpaint_controls .gradio-radio > div { display: flex; flex-direction: column !important; } - - /* move image preview/output on models page to bottom of page */ - #models_tab { flex-direction: column-reverse !important; } - /* fix settings for agent scheduler */ - #enqueue_keyboard_shortcut_modifiers, #enqueue_keyboard_shortcut_key div { max-width: 40% !important;} - - /* adjust width of certain settings item to allow aligning as row, but not have it go off the screen */ - #settings { display: flex; flex-direction: row; flex-wrap: wrap; max-width: 100% !important; } - #settings div.tab-content > div > div > div { max-width: 80% !important;} + #inpaint_controls .gradio-radio>div { display: flex; flex-direction: column !important; } + #models_tab { flex-direction: column-reverse !important; } /* move image preview/output on models page to bottom of page */ + #enqueue_keyboard_shortcut_modifiers, #enqueue_keyboard_shortcut_key div { max-width: 40% !important;} /* fix settings for agent scheduler */ + #settings { display: flex; flex-direction: row; flex-wrap: wrap; max-width: 100% !important; } /* adjust width of certain settings item to allow aligning as row, but not have it go off the screen */ + #settings div.tab-content>div>div>div { 
max-width: 80% !important;} #settings div .gradio-radio { width: unset !important; } - - /* enable scrolling on extensions tab */ - #tab_extensions table { border-collapse: collapse; display: block; overflow-x:auto !important;} - - /* increase scrollbar size to make it finger friendly */ - ::-webkit-scrollbar { width: 25px !important; height:25px; } - - /* adjust dropdown size to make them easier to select individual items on mobile. */ - .gradio-dropdown ul.options {max-height: 41vh !important; } + #tab_extensions table { border-collapse: collapse; display: block; overflow-x:auto !important;} /* enable scrolling on extensions tab */ + ::-webkit-scrollbar { width: 25px !important; height:25px; } /* increase scrollbar size to make it finger friendly */ + .gradio-dropdown ul.options {max-height: 41vh !important; } /* adjust dropdown size to make them easier to select individual items on mobile. */ .gradio-dropdown ul.options li.item {height: 40px !important; display: flex; align-items: center;} - - /* adjust slider input fields as they were too large for mobile devices. */ - .gradio-slider input[type="number"] { width: 4em; font-size: 0.8rem; height: 16px; text-align: center; } + .gradio-slider input[type="number"] { width: 4em; font-size: 0.8rem; height: 16px; text-align: center; } /* adjust slider input fields as they were too large for mobile devices. */ #txt2img_settings .block .padded:not(.gradio-accordion) {padding: 0 !important;margin-right: 0; min-width: 100% !important; width:100% !important;} - } - + } } diff --git a/javascript/settings.js b/javascript/settings.js index 93e4bebce..a787161dc 100644 --- a/javascript/settings.js +++ b/javascript/settings.js @@ -108,7 +108,7 @@ onAfterUiUpdate(async () => { const settingsSearch = gradioApp().querySelectorAll('#settings_search > label > textarea')[0]; settingsSearch.oninput = (e) => { setTimeout(() => { - log('settingsSearch', e.target.value) + log('settingsSearch', e.target.value); showAllSettings(); gradioApp().querySelectorAll('#tab_settings .tabitem').forEach((section) => { section.querySelectorAll('.dirtyable').forEach((setting) => { @@ -129,6 +129,40 @@ onOptionsChanged(() => { }); }); +async function initModels() { + const warn = () => ` +
No models available
+ - Select a model from reference list to download or
+ - Set model path to a folder containing your models
+ Current model path: ${opts.ckpt_dir}
+ `; + const el = gradioApp().getElementById('main_info'); + const en = gradioApp().getElementById('txt2img_extra_networks'); + if (!el || !en) return; + const req = await fetch('/sdapi/v1/sd-models'); + const res = req.ok ? await req.json() : []; + log('initModels', res.length); + const ready = () => ` +
Ready
+ ${res.length} models available
+ `; + el.innerHTML = res.length > 0 ? ready() : warn(); + el.style.display = 'block'; + setTimeout(() => el.style.display = 'none', res.length === 0 ? 30000 : 1500); + if (res.length === 0) { + if (en.classList.contains('hide')) gradioApp().getElementById('txt2img_extra_networks_btn').click(); + const repeat = setInterval(() => { + const buttons = Array.from(gradioApp().querySelectorAll('#txt2img_model_subdirs > button')) || []; + const reference = buttons.find((b) => b.innerText === 'Reference'); + if (reference) { + clearInterval(repeat); + reference.click(); + log('enReferenceSelect'); + } + }, 100); + } +} + function initSettings() { if (settingsInitialized) return; settingsInitialized = true; @@ -138,7 +172,7 @@ function initSettings() { const observer = new MutationObserver((mutations) => { const showAllPages = gradioApp().getElementById('settings_show_all_pages'); if (showAllPages.style.display === 'none') return; - const mutation = (mut) => mut.type === 'attributes' && mut.attributeName === 'style' + const mutation = (mut) => mut.type === 'attributes' && mut.attributeName === 'style'; if (mutations.some(mutation)) showAllSettings(); }); const tabContentWrapper = document.createElement('div'); @@ -155,3 +189,4 @@ function initSettings() { } onUiLoaded(initSettings); +onUiLoaded(initModels); diff --git a/javascript/timeless-beige.css b/javascript/timeless-beige.css new file mode 100644 index 000000000..b4b8c57e3 --- /dev/null +++ b/javascript/timeless-beige.css @@ -0,0 +1,297 @@ +/* generic html tags */ +:root, .light, .dark { + --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif; + --font-mono: 'ui-monospace', 'Consolas', monospace; + --font-size: 16px; + --primary-100: #212226; /* bg color*/ + --primary-200: #17181b; /* drop down menu/ prompt window fill*/ + --primary-300: #0a0c0e; /* black */ + --primary-400: #2f3034; /* small buttons*/ + --primary-500: #434242; /* main accent color retro beige*/ + --primary-700: #e75d5d; /* light blue gray*/ + --primary-800: #e75d5d; /* sat orange(hover accent)*/ + --highlight-color: var(--primary-500); + --inactive-color: var(--primary--800); + --body-text-color: var(--neutral-100); + --body-text-color-subdued: var(--neutral-300); + --background-color: var(--primary-100); + --background-fill-primary: var(--input-background-fill); + --input-padding: 8px; + --input-background-fill: var(--primary-200); + --input-shadow: none; + --button-secondary-text-color: white; + --button-secondary-background-fill: var(--primary-400); + --button-secondary-background-fill-hover: var(--primary-700); + --block-title-text-color: var(--neutral-300); + --radius-sm: 1px; + --radius-lg: 6px; + --spacing-md: 4px; + --spacing-xxl: 8px; + --line-sm: 1.2em; + --line-md: 1.4em; +} + +html { font-size: var(--font-size); } +body, button, input, select, textarea { font-family: var(--font);} +button { font-size: 1.2rem; max-width: 400px; } +img { background-color: var(--background-color); } +input[type=range] { height: var(--line-sm); appearance: none; margin-top: 0; min-width: 160px; background-color: var(--background-color); width: 100%; background: transparent; } +input[type=range]::-webkit-slider-runnable-track, input[type=range]::-moz-range-track { width: 100%; height: 6px; cursor: pointer; background: var(--primary-400); border-radius: var(--radius-lg); border: 0px solid #222222; } +input[type=range]::-webkit-slider-thumb, input[type=range]::-moz-range-thumb { border: 0px solid #000000; height: var(--line-sm); width: 8px; border-radius: 
var(--radius-lg); background: white; cursor: pointer; appearance: none; margin-top: 0px; } +input[type=range]::-moz-range-progress { background-color: var(--primary-500); height: 6px; border-radius: var(--radius-lg); } +::-webkit-scrollbar-track { background: #333333; } +::-webkit-scrollbar-thumb { background-color: var(--highlight-color); border-radius: var(--radius-lg); border-width: 0; box-shadow: 2px 2px 3px #111111; } +div.form { border-width: 0; box-shadow: none; background: transparent; overflow: visible; margin-bottom: 6px; } +div.compact { gap: 1em; } + +/* gradio style classes */ +fieldset .gr-block.gr-box, label.block span { padding: 0; margin-top: -4px; } +.border-2 { border-width: 0; } +.border-b-2 { border-bottom-width: 2px; border-color: var(--highlight-color) !important; padding-bottom: 2px; margin-bottom: 8px; } +.bg-white { color: lightyellow; background-color: var(--inactive-color); } +.gr-box { border-radius: var(--radius-sm) !important; background-color: #111111 !important; box-shadow: 2px 2px 3px #111111; border-width: 0; padding: 4px; margin: 12px 0px 12px 0px } +.gr-button { font-weight: normal; box-shadow: 2px 2px 3px #111111; font-size: 0.8rem; min-width: 32px; min-height: 32px; padding: 3px; margin: 3px; } +.gr-check-radio { background-color: var(--inactive-color); border-width: 0; border-radius: var(--radius-lg); box-shadow: 2px 2px 3px #111111; } +.gr-check-radio:checked { background-color: var(--highlight-color); } +.gr-compact { background-color: var(--background-color); } +.gr-form { border-width: 0; } +.gr-input { background-color: #333333 !important; padding: 4px; margin: 4px; } +.gr-input-label { color: lightyellow; border-width: 0; background: transparent; padding: 2px !important; } +.gr-panel { background-color: var(--background-color); } +.eta-bar { display: none !important } +svg.feather.feather-image, .feather .feather-image { display: none } +.gap-2 { padding-top: 8px; } +.gr-box > div > div > input.gr-text-input { right: 0; width: 4em; padding: 0; top: -12px; border: none; max-height: 20px; } +.output-html { line-height: 1.2rem; overflow-x: hidden; } +.output-html > div { margin-bottom: 8px; } +.overflow-hidden .flex .flex-col .relative col .gap-4 { min-width: var(--left-column); max-width: var(--left-column); } /* this is a problematic one */ +.p-2 { padding: 0; } +.px-4 { padding-lefT: 1rem; padding-right: 1rem; } +.py-6 { padding-bottom: 0; } +.tabs { background-color: var(--background-color); } +.block.token-counter span { background-color: var(--input-background-fill) !important; box-shadow: 2px 2px 2px #111; border: none !important; font-size: 0.8rem; } +.tab-nav { zoom: 120%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; } +div.tab-nav button.selected {background-color: var(--button-primary-background-fill);} +#settings div.tab-nav button.selected {background-color: var(--background-color); color: var(--primary-800); font-weight: bold;} +.label-wrap { background-color: #292b30; /* extension tab color*/ padding: 16px 8px 8px 8px; border-radius: var(--radius-lg); padding-left: 8px !important; } +.small-accordion .label-wrap { padding: 8px 0px 8px 0px; } +.small-accordion .label-wrap .icon { margin-right: 1em; } +.gradio-button.tool { border: none; box-shadow: none; border-radius: var(--radius-lg);} +button.selected {background: var(--button-primary-background-fill);} +.center.boundedheight.flex {background-color: var(--input-background-fill);} +.compact {border-radius: 
var(--border-radius-lg);} +#logMonitorData {background-color: var(--input-background-fill);} +#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; background-color: var(--primary-200); } +#tab_extensions table, #tab_config table { width: 96vw; } +#tab_extensions table input[type=checkbox] {appearance: none; border-radius: 0px;} +#tab_extensions button:hover { background-color: var(--button-secondary-background-fill-hover);} + +/* automatic style classes */ +.progressDiv { border-radius: var(--radius-sm) !important; position: fixed; top: 44px; right: 26px; max-width: 262px; height: 48px; z-index: 99; box-shadow: var(--button-shadow); } +.progressDiv .progress { border-radius: var(--radius-lg) !important; background: var(--highlight-color); line-height: 3rem; height: 48px; } +.gallery-item { box-shadow: none !important; } +.performance { color: #888; } +.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; } +.image-buttons { gap: 10px !important; justify-content: center; } +.image-buttons > button { max-width: 160px; } +.tooltip { background: var(--primary-800); color: white; border: none; border-radius: var(--radius-lg) } +#system_row > button, #settings_row > button, #config_row > button { max-width: 190px; } + +/* gradio elements overrides */ +#div.gradio-container { overflow-x: hidden; } +#img2img_label_copy_to_img2img { font-weight: normal; } +#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt { background-color: var(--background-color); box-shadow: 4px 4px 4px 0px #333333 !important; } +#txt2img_prompt > label > textarea, #txt2img_neg_prompt > label > textarea, #img2img_prompt > label > textarea, #img2img_neg_prompt > label > textarea { font-size: 1.1rem; } +#img2img_settings { min-width: calc(2 * var(--left-column)); max-width: calc(2 * var(--left-column)); background-color: #111111; padding-top: 16px; } +#interrogate, #deepbooru { margin: 0 0px 10px 0px; max-width: 80px; max-height: 80px; font-weight: normal; font-size: 0.95em; } +#quicksettings .gr-button-tool { font-size: 1.6rem; box-shadow: none; margin-top: -2px; height: 2.4em; } +#quicksettings button {padding: 0 0.5em 0.1em 0.5em;} +#open_folder_extras, #footer, #style_pos_col, #style_neg_col, #roll_col, #extras_upscaler_2, #extras_upscaler_2_visibility, #txt2img_seed_resize_from_w, #txt2img_seed_resize_from_h { display: none; } +#save-animation { border-radius: var(--radius-sm) !important; margin-bottom: 16px; background-color: #111111; } +#script_list { padding: 4px; margin-top: 16px; margin-bottom: 8px; } +#settings > div.flex-wrap { width: 15em; } +#txt2img_cfg_scale { min-width: 200px; } +#txt2img_checkboxes, #img2img_checkboxes { background-color: transparent; } +#txt2img_checkboxes, #img2img_checkboxes { margin-bottom: 0.2em; } +#txt2img_actions_column, #img2img_actions_column { flex-flow: wrap; justify-content: space-between; } +#txt2img_enqueue_wrapper, #img2img_enqueue_wrapper { min-width: unset; width: 48%; } +#txt2img_generate_box, #img2img_generate_box { min-width: unset; width: 48%; } + +#extras_upscale { margin-top: 10px } +#txt2img_progress_row > div { min-width: var(--left-column); max-width: var(--left-column); } +#txt2img_settings { min-width: var(--left-column); max-width: var(--left-column); background-color: #111111; padding-top: 16px; } +#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) } 
+#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; } +#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; } + +/* based on gradio built-in dark theme */ +:root, .light, .dark { + --body-background-fill: var(--background-color); + --color-accent-soft: var(--neutral-700); + --background-fill-secondary: none; + --border-color-accent: var(--background-color); + --border-color-primary: var(--background-color); + --link-text-color-active: var(--primary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --shadow-spread: 1px; + --block-background-fill: None; + --block-border-color: var(--border-color-primary); + --block_border_width: None; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: None; + --block-label-text-color: var(--neutral-200); + --block_shadow: None; + --block_title_background_fill: None; + --block_title_border_color: None; + --block_title_border_width: None; + --panel-background-fill: var(--background-fill-secondary); + --panel-border-color: var(--border-color-primary); + --panel_border_width: None; + --checkbox-background-color: var(--primary-400); + --checkbox-background-color-focus: var(--primary-700); + --checkbox-background-color-hover: var(--primary-700); + --checkbox-background-color-selected: var(--primary-500); + --checkbox-border-color: transparent; + --checkbox-border-color-focus: var(--primary-800); + --checkbox-border-color-hover: var(--primary-800); + --checkbox-border-color-selected: var(--primary-800); + --checkbox-border-width: var(--input-border-width); + --checkbox-label-background-fill: None; + --checkbox-label-background-fill-hover: None; + --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); + --checkbox-label-border-color: var(--border-color-primary); + --checkbox-label-border-color-hover: var(--checkbox-label-border-color); + --checkbox-label-border-width: var(--input-border-width); + --checkbox-label-text-color: var(--body-text-color); + --checkbox-label-text-color-selected: var(--checkbox-label-text-color); + --error-background-fill: var(--background-fill-primary); + --error-border-color: var(--border-color-primary); + --error-text-color: #f768b7; /*was ef4444*/ + --input-background-fill-focus: var(--secondary-600); + --input-background-fill-hover: var(--input-background-fill); + --input-border-color: var(--background-color); + --input-border-color-focus: var(--primary-800); + --input-placeholder-color: var(--neutral-500); + --input-shadow-focus: None; + --loader_color: None; + --slider_color: None; + --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-800)); + --table-border-color: var(--neutral-700); + --table-even-background-fill: var(--primary-300); + --table-odd-background-fill: var(--primary-200); + --table-row-focus: var(--color-accent-soft); + --button-border-width: var(--input-border-width); + --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); + --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); + --button-cancel-border-color: #dc2626; + --button-cancel-border-color-hover: var(--button-cancel-border-color); + --button-cancel-text-color: white; + --button-cancel-text-color-hover: var(--button-cancel-text-color); 
+ --button-primary-background-fill: var(--primary-500); + --button-primary-background-fill-hover: var(--primary-800); + --button-primary-border-color: var(--primary-500); + --button-primary-border-color-hover: var(--button-primary-border-color); + --button-primary-text-color: white; + --button-primary-text-color-hover: var(--button-primary-text-color); + --button-secondary-border-color: var(--neutral-600); + --button-secondary-border-color-hover: var(--button-secondary-border-color); + --button-secondary-text-color-hover: var(--button-secondary-text-color); + --secondary-50: #eff6ff; + --secondary-100: #dbeafe; + --secondary-200: #bfdbfe; + --secondary-300: #93c5fd; + --secondary-400: #60a5fa; + --secondary-500: #3b82f6; + --secondary-600: #2563eb; + --secondary-700: #1d4ed8; + --secondary-800: #1e40af; + --secondary-900: #1e3a8a; + --secondary-950: #1d3660; + --neutral-50: #f0f0f0; /* */ + --neutral-100: #e0dedc;/* majority of text (neutral gray yellow) */ + --neutral-200: #d0d0d0; + --neutral-300: #9d9dab; /* top tab text (light accent) */ + --neutral-400: #ffba85;/* tab title (light beige) */ + --neutral-500: #484746; /* prompt text (desat accent)*/ + --neutral-600: #605a54; /* tab outline color (accent color)*/ + --neutral-700: #1b1c1e; /* small settings tab accent (dark)*/ + --neutral-800: #e75d5d; /* bright orange accent */ + --neutral-900: #111827; + --neutral-950: #0b0f19; + --radius-xxs: 0; + --radius-xs: 0; + --radius-md: 0; + --radius-xl: 0; + --radius-xxl: 0; + --body-text-size: var(--text-md); + --body-text-weight: 400; + --embed-radius: var(--radius-lg); + --color-accent: var(--primary-500); + --shadow-drop: 0; + --shadow-drop-lg: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1); + --shadow-inset: rgba(0,0,0,0.05) 0px 2px 4px 0px inset; + --block-border-width: 1px; + --block-info-text-size: var(--text-sm); + --block-info-text-weight: 400; + --block-label-border-width: 1px; + --block-label-margin: 0; + --block-label-padding: var(--spacing-sm) var(--spacing-lg); + --block-label-radius: calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px) 0; + --block-label-right-radius: 0 calc(var(--radius-lg) - 1px) 0 calc(var(--radius-lg) - 1px); + --block-label-text-size: var(--text-sm); + --block-label-text-weight: 400; + --block-padding: var(--spacing-xl) calc(var(--spacing-xl) + 2px); + --block-radius: var(--radius-lg); + --block-shadow: var(--shadow-drop); + --block-title-background-fill: none; + --block-title-border-color: none; + --block-title-border-width: 0; + --block-title-padding: 0; + --block-title-radius: none; + --block-title-text-size: var(--text-md); + --block-title-text-weight: 400; + --container-radius: var(--radius-lg); + --form-gap-width: 1px; + --layout-gap: var(--spacing-xxl); + --panel-border-width: 0; + --section-header-text-size: var(--text-md); + --section-header-text-weight: 400; + --checkbox-border-radius: var(--radius-sm); + --checkbox-label-gap: 2px; + --checkbox-label-padding: var(--spacing-md); + --checkbox-label-shadow: var(--shadow-drop); + --checkbox-label-text-size: var(--text-md); + --checkbox-label-text-weight: 400; + --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e"); + --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' 
r='3'/%3e%3c/svg%3e"); + --checkbox-shadow: var(--input-shadow); + --error-border-width: 1px; + --input-border-width: 1px; + --input-radius: var(--radius-lg); + --input-text-size: var(--text-md); + --input-text-weight: 400; + --loader-color: var(--color-accent); + --prose-text-size: var(--text-md); + --prose-text-weight: 400; + --prose-header-text-weight: 600; + --slider-color: ; + --table-radius: var(--radius-lg); + --button-large-padding: 2px 6px; + --button-large-radius: var(--radius-lg); + --button-large-text-size: var(--text-lg); + --button-large-text-weight: 400; + --button-shadow: none; + --button-shadow-active: none; + --button-shadow-hover: none; + --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm)); + --button-small-radius: var(--radius-lg); + --button-small-text-size: var(--text-md); + --button-small-text-weight: 400; + --button-transition: none; + --size-9: 64px; + --size-14: 64px; +} diff --git a/javascript/ui.js b/javascript/ui.js index daffbe352..7902aaacf 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -66,38 +66,54 @@ function extract_image_from_gallery(gallery) { window.args_to_array = Array.from; // Compatibility with e.g. extensions that may expect this to be around +function switchToTab(tab) { + const tabs = Array.from(gradioApp().querySelectorAll('#tabs > .tab-nav > button')); + const btn = tabs?.find((t) => t.innerText === tab); + log('switchToTab', tab); + if (btn) btn.click(); +} + function switch_to_txt2img(...args) { - gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click(); + switchToTab('Text'); return Array.from(arguments); } function switch_to_img2img_tab(no) { - gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click(); + switchToTab('Image'); gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click(); } function switch_to_img2img(...args) { + switchToTab('Image'); switch_to_img2img_tab(0); return Array.from(arguments); } function switch_to_sketch(...args) { + switchToTab('Image'); switch_to_img2img_tab(1); return Array.from(arguments); } function switch_to_inpaint(...args) { + switchToTab('Image'); switch_to_img2img_tab(2); return Array.from(arguments); } function switch_to_inpaint_sketch(...args) { + switchToTab('Image'); switch_to_img2img_tab(3); return Array.from(arguments); } function switch_to_extras(...args) { - gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click(); + switchToTab('Process'); + return Array.from(arguments); +} + +function switch_to_control(...args) { + switchToTab('Control'); return Array.from(arguments); } @@ -164,6 +180,17 @@ function submit_img2img(...args) { return res; } +function submit_control(...args) { + log('submitControl'); + clearGallery('control'); + const id = randomId(); + requestProgress(id, null, gradioApp().getElementById('control_gallery')); + const res = create_submit_args(args); + res[0] = id; + res[1] = gradioApp().querySelector('#control-tabs > .tab-nav > .selected')?.innerText.toLowerCase() || ''; // selected tab name + return res; +} + function submit_postprocessing(...args) { log('SubmitExtras'); clearGallery('extras'); @@ -211,6 +238,12 @@ function recalculate_prompts_inpaint(...args) { return Array.from(arguments); } +function recalculate_prompts_control(...args) { + recalculatePromptTokens('control_prompt'); + recalculatePromptTokens('control_neg_prompt'); + return Array.from(arguments); +} + function registerDragDrop() { const qs = gradioApp().getElementById('quicksettings'); if (!qs) return; @@ -279,6 
+312,8 @@ onAfterUiUpdate(async () => { registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); + registerTextarea('control_prompt', 'control_token_counter', 'control_token_button'); + registerTextarea('control_neg_prompt', 'control_negative_token_counter', 'control_negative_token_button'); }); function update_txt2img_tokens(...args) { diff --git a/launch.py b/launch.py index a8c156b8c..4617e5f60 100755 --- a/launch.py +++ b/launch.py @@ -47,6 +47,16 @@ def get_custom_args(): if current != default: custom[arg] = getattr(args, arg) installer.log.info(f'Command line args: {sys.argv[1:]} {installer.print_dict(custom)}') + if os.environ.get('SD_ENV_DEBUG', None) is not None: + env = os.environ.copy() + if 'PATH' in env: + del env['PATH'] + if 'PS1' in env: + del env['PS1'] + installer.log.trace(f'Environment: {installer.print_dict(env)}') + else: + env = [f'{k}={v}' for k, v in os.environ.items() if k.startswith('SD_')] + installer.log.debug(f'Env flags: {env}') @lru_cache() @@ -205,7 +215,6 @@ def start_server(immediate=True, server=None): installer.log.info('Startup: standard') installer.install_requirements() installer.install_packages() - installer.install_repositories() installer.install_submodules() init_paths() installer.install_extensions() diff --git a/models/Reference/amused--amused-256.jpg b/models/Reference/amused--amused-256.jpg new file mode 100644 index 000000000..f410817a8 Binary files /dev/null and b/models/Reference/amused--amused-256.jpg differ diff --git a/models/Reference/amused--amused-512.jpg b/models/Reference/amused--amused-512.jpg new file mode 100644 index 000000000..0b8e26240 Binary files /dev/null and b/models/Reference/amused--amused-512.jpg differ diff --git a/models/Reference/damo-vilab--text-to-video-ms-1.7b.jpg b/models/Reference/damo-vilab--text-to-video-ms-1.7b.jpg new file mode 100644 index 000000000..39ad83eea Binary files /dev/null and b/models/Reference/damo-vilab--text-to-video-ms-1.7b.jpg differ diff --git a/models/Reference/dreamshaperXL_turboDpmppSDE.jpg b/models/Reference/dreamshaperXL_turboDpmppSDE.jpg new file mode 100644 index 000000000..802d92593 Binary files /dev/null and b/models/Reference/dreamshaperXL_turboDpmppSDE.jpg differ diff --git a/models/Reference/dreamshaper_8.jpg b/models/Reference/dreamshaper_8.jpg new file mode 100644 index 000000000..369d540b3 Binary files /dev/null and b/models/Reference/dreamshaper_8.jpg differ diff --git a/models/Reference/juggernautXL_v7Rundiffusion.jpg b/models/Reference/juggernautXL_v7Rundiffusion.jpg new file mode 100644 index 000000000..cbce7cb32 Binary files /dev/null and b/models/Reference/juggernautXL_v7Rundiffusion.jpg differ diff --git a/models/Reference/juggernaut_reborn.jpg b/models/Reference/juggernaut_reborn.jpg new file mode 100644 index 000000000..f19b294e9 Binary files /dev/null and b/models/Reference/juggernaut_reborn.jpg differ diff --git a/models/Reference/playgroundai--playground-v1.jpg b/models/Reference/playgroundai--playground-v1.jpg new file mode 100644 index 000000000..5a2cf4b5b Binary files /dev/null and b/models/Reference/playgroundai--playground-v1.jpg differ diff --git a/models/Reference/playgroundai--playground-v2-1024px-aesthetic.jpg b/models/Reference/playgroundai--playground-v2-1024px-aesthetic.jpg new file mode 100644 index 
000000000..f1eca8c9d Binary files /dev/null and b/models/Reference/playgroundai--playground-v2-1024px-aesthetic.jpg differ diff --git a/models/Reference/playgroundai--playground-v2-256px-base.jpg b/models/Reference/playgroundai--playground-v2-256px-base.jpg new file mode 100644 index 000000000..42a332b91 Binary files /dev/null and b/models/Reference/playgroundai--playground-v2-256px-base.jpg differ diff --git a/models/Reference/playgroundai--playground-v2-512px-base.jpg b/models/Reference/playgroundai--playground-v2-512px-base.jpg new file mode 100644 index 000000000..759945df4 Binary files /dev/null and b/models/Reference/playgroundai--playground-v2-512px-base.jpg differ diff --git a/models/Reference/salesforce--blipdiffusion.jpg b/models/Reference/salesforce--blipdiffusion.jpg new file mode 100644 index 000000000..79155aba9 Binary files /dev/null and b/models/Reference/salesforce--blipdiffusion.jpg differ diff --git a/models/Reference/segmind--Segmind-Vega.jpg b/models/Reference/segmind--Segmind-Vega.jpg new file mode 100644 index 000000000..8356b0ae8 Binary files /dev/null and b/models/Reference/segmind--Segmind-Vega.jpg differ diff --git a/models/Reference/stabilityai--stable-video-diffusion-img2vid-xt.jpg b/models/Reference/stabilityai--stable-video-diffusion-img2vid-xt.jpg index 81a809695..822d1cd9c 100644 Binary files a/models/Reference/stabilityai--stable-video-diffusion-img2vid-xt.jpg and b/models/Reference/stabilityai--stable-video-diffusion-img2vid-xt.jpg differ diff --git a/models/Reference/stabilityai--stable-video-diffusion-img2vid.jpg b/models/Reference/stabilityai--stable-video-diffusion-img2vid.jpg index 81a809695..1c3513ae4 100644 Binary files a/models/Reference/stabilityai--stable-video-diffusion-img2vid.jpg and b/models/Reference/stabilityai--stable-video-diffusion-img2vid.jpg differ diff --git a/modules/api/api.py b/modules/api/api.py index ef0825926..1f60e5c01 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -13,7 +13,7 @@ import piexif import piexif.helper import gradio as gr -from modules import errors, shared, sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing +from modules import errors, shared, sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, script_callbacks, generation_parameters_copypaste from modules.sd_vae import vae_dict from modules.api import models from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -133,7 +133,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) - self.add_api_route("/sdapi/v1/refresh-vaes", self.refresh_vaes, methods=["POST"]) + self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vaes, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) @@ -145,6 +145,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): 
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem]) self.add_api_route("/sdapi/v1/log", self.get_log_buffer, methods=["GET"], response_model=List) self.add_api_route("/sdapi/v1/start", self.session_start, methods=["GET"]) self.add_api_route("/sdapi/v1/motd", self.get_motd, methods=["GET"], response_model=str) @@ -287,14 +288,14 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): p.scripts = script_runner p.outpath_grids = shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids p.outpath_samples = shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples - shared.state.begin('api-txt2img') + shared.state.begin('api-txt2img', api=True) script_args = self.init_script_args(p, txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) if selectable_scripts is not None: processed = scripts.scripts_txt2img.run(p, *script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) - shared.state.end() + shared.state.end(api=False) b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -335,14 +336,14 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): p.scripts = script_runner p.outpath_grids = shared.opts.outdir_img2img_grids p.outpath_samples = shared.opts.outdir_img2img_samples - shared.state.begin('api-img2img') + shared.state.begin('api-img2img', api=True) script_args = self.init_script_args(p, img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) if selectable_scripts is not None: processed = scripts.scripts_img2img.run(p, *script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) - shared.state.end() + shared.state.end(api=False) b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] if not img2imgreq.include_init_images: @@ -368,14 +369,22 @@ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest): def pnginfoapi(self, req: models.PNGInfoRequest): if not req.image.strip(): return models.PNGInfoResponse(info="") + image = decode_base64_to_image(req.image.strip()) if image is None: return models.PNGInfoResponse(info="") + geninfo, items = images.read_info_from_image(image) if geninfo is None: geninfo = "" - items = {**{'parameters': geninfo}, **items} - return models.PNGInfoResponse(info=geninfo, items=items) + + if items and items['parameters']: + del items['parameters'] + + params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + script_callbacks.infotext_pasted_callback(geninfo, params) + + return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) def progressapi(self, req: models.ProgressRequest = Depends()): if shared.state.job_count == 0: @@ -464,7 +473,7 @@ def get_upscalers(self): return [{"name": upscaler.name, "model_name": 
upscaler.scaler.model_name, "model_path": upscaler.data_path, "model_url": None, "scale": upscaler.scale} for upscaler in shared.sd_upscalers] def get_sd_models(self): - return [{"title": x.title, "name": x.name, "filename": x.filename, "type": x.type, "hash": x.shorthash, "sha256": x.sha256, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.name, "filename": x.filename, "type": x.type, "hash": x.shorthash, "sha256": x.sha256, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -640,6 +649,25 @@ def get_memory(self): cuda = { 'error': f'{err}' } return models.MemoryResponse(ram = ram, cuda = cuda) + def get_extensions_list(self): + from modules import extensions + extensions.list_extensions() + ext_list = [] + for ext in extensions.extensions: + ext: extensions.Extension + ext.read_info() + if ext.remote is not None: + ext_list.append({ + "name": ext.name, + "remote": ext.remote, + "branch": ext.branch, + "commit_hash":ext.commit_hash, + "commit_date":ext.commit_date, + "version":ext.version, + "enabled":ext.enabled + }) + return ext_list + def launch(self): config = { "listen": shared.cmd_opts.listen, diff --git a/modules/api/models.py b/modules/api/models.py index 3ed73a42e..9d90f0ca7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -170,7 +170,8 @@ class PNGInfoRequest(BaseModel): class PNGInfoResponse(BaseModel): info: str = Field(title="Image info", description="A string with the parameters used to generate the image") - items: dict = Field(title="Items", description="An object containing all the info the image had") + items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had") + parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields") class LogRequest(BaseModel): lines: int = Field(default=100, title="Lines", description="How many lines to return") @@ -209,7 +210,7 @@ class PreprocessResponse(BaseModel): if metadata is not None: fields.update({key: (Optional[optType], Field( - default=metadata.default ,description=metadata.label))}) + default=metadata.default, description=metadata.label))}) else: fields.update({key: (Optional[optType], Field())}) @@ -245,7 +246,7 @@ class UpscalerItem(BaseModel): class SDModelItem(BaseModel): title: str = Field(title="Title") - name: str = Field(title="Model Name") + model_name: str = Field(title="Model Name") filename: str = Field(title="Filename") type: str = Field(title="Model type") sha256: Optional[str] = Field(title="SHA256 hash") @@ -286,7 +287,6 @@ class ExtraNetworkItem(BaseModel): # metadata: Optional[Any] = Field(title="Metadata") # local: Optional[str] = Field(title="Local") - class ArtistItem(BaseModel): name: str = Field(title="Name") score: float = Field(title="Score") @@ -311,7 +311,6 @@ class ScriptsList(BaseModel): txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)") img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)") - class ScriptArg(BaseModel): label: str = Field(default=None, title="Label", description="Name of the argument in UI") value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument") @@ -320,9 +319,17 @@ class ScriptArg(BaseModel): step: 
Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI") choices: Optional[Any] = Field(default=None, title="Choices", description="Possible values for the argument") - class ScriptInfo(BaseModel): name: str = Field(default=None, title="Name", description="Script name") is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") + +class ExtensionItem(BaseModel): + name: str = Field(title="Name", description="Extension name") + remote: str = Field(title="Remote", description="Extension Repository URL") + branch: str = Field(title="Branch", description="Extension Repository Branch") + commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash") + version: str = Field(title="Version", description="Extension Version") + commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date") + enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled") diff --git a/modules/call_queue.py b/modules/call_queue.py index 21f94352d..482a87e82 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -76,7 +76,11 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): if not shared.mem_mon.disabled: vram = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.read().items()} if vram.get('active_peak', 0) > 0: - vram_html = f" |
<p class='vram'>GPU active {max(vram['active_peak'], vram['reserved_peak'])} MB reserved {vram['reserved']} | used {vram['used']} MB free {vram['free']} MB total {vram['total']} MB | retries {vram['retries']} oom {vram['oom']}</p>"
- res[-1] += f"<div class='performance'><p class='time'>Time: {elapsed_text}</p>{vram_html}</div>"
+ vram_html = " | <p class='vram'>"
+ vram_html += f"GPU active {max(vram['active_peak'], vram['reserved_peak'])} MB reserved {vram['reserved']} | used {vram['used']} MB free {vram['free']} MB total {vram['total']} MB"
+ vram_html += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else ''
+ vram_html += "</p>"
+ if isinstance(res, list):
+ res[-1] += f"<div class='performance'><p class='time'>Time: {elapsed_text}</p>{vram_html}</div>
" return tuple(res) return f diff --git a/modules/control/proc/__init__.py b/modules/control/proc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/canny.py b/modules/control/proc/canny.py new file mode 100644 index 000000000..1f450b9f6 --- /dev/null +++ b/modules/control/proc/canny.py @@ -0,0 +1,36 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from modules.control.util import HWC3, resize_image + +class CannyDetector: + def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs): + if "img" in kwargs: + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("img") + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + detected_map = cv2.Canny(input_image, low_threshold, high_threshold) + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + detected_map = detected_map.convert('L') + + return detected_map diff --git a/modules/control/proc/dwpose/__init__.py b/modules/control/proc/dwpose/__init__.py new file mode 100644 index 000000000..a0c5c513b --- /dev/null +++ b/modules/control/proc/dwpose/__init__.py @@ -0,0 +1,90 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import cv2 +import numpy as np +from PIL import Image + +from modules.control.util import HWC3, resize_image +from .draw import draw_bodypose, draw_handpose, draw_facepose + + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + canvas = draw_bodypose(canvas, candidate, subset) + canvas = draw_handpose(canvas, hands) + canvas = draw_facepose(canvas, faces) + + return canvas + +class DWposeDetector: + def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"): + from .wholebody import Wholebody + + self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device) + + def to(self, device): + self.pose_estimation.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", min_confidence=0.3, **kwargs): + input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + H, W, _C = input_image.shape + + candidate, subset = self.pose_estimation(input_image) + if candidate is None: + return Image.fromarray(input_image) + nums, _keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = 
candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > min_confidence: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset < min_confidence + candidate[un_visible] = -1 + + _foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + detected_map = draw_pose(pose, H, W) + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/dwpose/config/dwpose-l_384x288.py b/modules/control/proc/dwpose/config/dwpose-l_384x288.py new file mode 100644 index 000000000..e054be72e --- /dev/null +++ b/modules/control/proc/dwpose/config/dwpose-l_384x288.py @@ -0,0 +1,257 @@ +# runtime +max_epochs = 270 +stage2_num_epochs = 30 +base_lr = 4e-3 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=21) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# codec settings +codec = dict( + type='SimCCLabel', + input_size=(288, 384), + sigma=(6., 6.93), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=1., + widen_factor=1., + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' + 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' + )), + head=dict( + type='RTMCCHead', + in_channels=1024, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=(9, 12), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10., + label_softmax=True), + decoder=codec), + test_cfg=dict(flip_test=True, )) + +# base dataset settings +dataset_type = 'CocoWholeBodyDataset' +data_mode = 'topdown' +data_root = '/data/' + +backend_args = dict(backend='local') +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# 
f'{data_root}': 's3://openmmlab/datasets/detection/coco/', +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/' +# })) + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +datasets = [] +dataset_coco=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='coco/train2017/'), + pipeline=[], +) +datasets.append(dataset_coco) + +scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', + 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', + 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] + +for i in range(len(scene)): + datasets.append( + dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', + data_prefix=dict(img='UBody/images/'+scene[i]+'/'), + pipeline=[], + ) + ) + +# data loaders +train_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=datasets, + pipeline=train_pipeline, + test_mode=False, + )) +val_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_val_v1.0.json', + bbox_file=f'{data_root}coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='coco/val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + 
+custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator diff --git a/modules/control/proc/dwpose/config/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py b/modules/control/proc/dwpose/config/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py new file mode 100644 index 000000000..3c89683f0 --- /dev/null +++ b/modules/control/proc/dwpose/config/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py @@ -0,0 +1,259 @@ +# _base_ = ['../../../_base_/default_runtime.py'] + +# runtime +max_epochs = 270 +stage2_num_epochs = 30 +base_lr = 4e-3 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=21) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# codec settings +codec = dict( + type='SimCCLabel', + input_size=(288, 384), + sigma=(6., 6.93), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=1., + widen_factor=1., + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' + 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' + )), + head=dict( + type='RTMCCHead', + in_channels=1024, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=(9, 12), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10., + label_softmax=True), + decoder=codec), + test_cfg=dict(flip_test=True, )) + +# base dataset settings +dataset_type = 'CocoWholeBodyDataset' +data_mode = 'topdown' +data_root = 'data/' + +backend_args = dict(backend='local') +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/', +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/' +# })) + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + 
dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +datasets = [] +dataset_coco=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='coco/train2017/'), + pipeline=[], +) +datasets.append(dataset_coco) + +scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', + 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', + 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] + +for i in range(len(scene)): + datasets.append( + dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', + data_prefix=dict(img='UBody/images/'+scene[i]+'/'), + pipeline=[], + ) + ) + +# data loaders +train_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=datasets, + pipeline=train_pipeline, + test_mode=False, + )) +val_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_val_v1.0.json', + bbox_file=f'{data_root}coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='coco/val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# evaluators +val_evaluator = 
dict( + type='CocoWholeBodyMetric', + ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator diff --git a/modules/control/proc/dwpose/config/rtmpose-m_8xb64-270e_coco-ubody-wholebody-256x192.py b/modules/control/proc/dwpose/config/rtmpose-m_8xb64-270e_coco-ubody-wholebody-256x192.py new file mode 100644 index 000000000..20adeea96 --- /dev/null +++ b/modules/control/proc/dwpose/config/rtmpose-m_8xb64-270e_coco-ubody-wholebody-256x192.py @@ -0,0 +1,259 @@ +# _base_ = ['../../../_base_/default_runtime.py'] + +# runtime +max_epochs = 270 +stage2_num_epochs = 30 +base_lr = 4e-3 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=21) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# codec settings +codec = dict( + type='SimCCLabel', + input_size=(192, 256), + sigma=(4.9, 5.66), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=0.67, + widen_factor=0.75, + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' + 'rtmpose/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth' + )), + head=dict( + type='RTMCCHead', + in_channels=768, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=(6, 8), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10., + label_softmax=True), + decoder=codec), + test_cfg=dict(flip_test=True, )) + +# base dataset settings +dataset_type = 'CocoWholeBodyDataset' +data_mode = 'topdown' +data_root = 'data/' + +backend_args = dict(backend='local') +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/', +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/' +# })) + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + 
dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +datasets = [] +dataset_coco=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='coco/train2017/'), + pipeline=[], +) +datasets.append(dataset_coco) + +scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', + 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', + 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] + +for i in range(len(scene)): + datasets.append( + dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', + data_prefix=dict(img='UBody/images/'+scene[i]+'/'), + pipeline=[], + ) + ) + +# data loaders +train_dataloader = dict( + batch_size=64, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=datasets, + pipeline=train_pipeline, + test_mode=False, + )) +val_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_val_v1.0.json', + bbox_file=f'{data_root}coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='coco/val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator diff --git a/modules/control/proc/dwpose/config/rtmpose-t_8xb64-270e_coco-ubody-wholebody-256x192.py 
b/modules/control/proc/dwpose/config/rtmpose-t_8xb64-270e_coco-ubody-wholebody-256x192.py new file mode 100644 index 000000000..d97c2f78b --- /dev/null +++ b/modules/control/proc/dwpose/config/rtmpose-t_8xb64-270e_coco-ubody-wholebody-256x192.py @@ -0,0 +1,259 @@ +# _base_ = ['../../../_base_/default_runtime.py'] + +# runtime +max_epochs = 270 +stage2_num_epochs = 30 +base_lr = 4e-3 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=21) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# codec settings +codec = dict( + type='SimCCLabel', + input_size=(192, 256), + sigma=(4.9, 5.66), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=0.167, + widen_factor=0.375, + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' + 'rtmpose/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth' + )), + head=dict( + type='RTMCCHead', + in_channels=384, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=(6, 8), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10., + label_softmax=True), + decoder=codec), + test_cfg=dict(flip_test=True, )) + +# base dataset settings +dataset_type = 'CocoWholeBodyDataset' +data_mode = 'topdown' +data_root = 'data/' + +backend_args = dict(backend='local') +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/', +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/' +# })) + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + 
dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +datasets = [] +dataset_coco=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='coco/train2017/'), + pipeline=[], +) +datasets.append(dataset_coco) + +scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', + 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', + 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] + +for i in range(len(scene)): + datasets.append( + dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', + data_prefix=dict(img='UBody/images/'+scene[i]+'/'), + pipeline=[], + ) + ) + +# data loaders +train_dataloader = dict( + batch_size=64, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=datasets, + pipeline=train_pipeline, + test_mode=False, + )) +val_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_val_v1.0.json', + bbox_file=f'{data_root}coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='coco/val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator diff --git a/modules/control/proc/dwpose/config/yolox_l_8xb8-300e_coco.py b/modules/control/proc/dwpose/config/yolox_l_8xb8-300e_coco.py new file mode 100644 index 000000000..7b4cb5a4b --- /dev/null +++ b/modules/control/proc/dwpose/config/yolox_l_8xb8-300e_coco.py @@ -0,0 +1,245 @@ +img_scale = (640, 640) # width, height + +# model settings +model = dict( + type='YOLOX', + 
data_preprocessor=dict( + type='DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='BatchSyncRandomResize', + random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + backbone=dict( + type='CSPDarknet', + deepen_factor=1.0, + widen_factor=1.0, + out_indices=(2, 3, 4), + use_depthwise=False, + spp_kernal_sizes=(5, 9, 13), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + ), + neck=dict( + type='YOLOXPAFPN', + in_channels=[256, 512, 1024], + out_channels=256, + num_csp_blocks=3, + use_depthwise=False, + upsample_cfg=dict(scale_factor=2, mode='nearest'), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish')), + bbox_head=dict( + type='YOLOXHead', + num_classes=80, + in_channels=256, + feat_channels=256, + stacked_convs=2, + strides=(8, 16, 32), + use_depthwise=False, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_bbox=dict( + type='IoULoss', + mode='square', + eps=1e-16, + reduction='sum', + loss_weight=5.0), + loss_obj=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)), + train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), + # In order to align the source code, the threshold of the val phase is + # 0.01, and the threshold of the test phase is 0.001. + test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) + +# dataset settings +data_root = 'data/coco/' +dataset_type = 'CocoDataset' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), + dict( + type='RandomAffine', + scaling_ratio_range=(0.1, 2), + # img_scale is (width, height) + border=(-img_scale[0] // 2, -img_scale[1] // 2)), + dict( + type='MixUp', + img_scale=img_scale, + ratio_range=(0.8, 1.6), + pad_val=114.0), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + # According to the official implementation, multi-scale + # training is not considered here but in the + # 'mmdet/models/detectors/yolox.py'. + # Resize and Pad are for the last 15 epochs when Mosaic, + # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook. + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + # If the image is three-channel, the pad value needs + # to be set separately for each channel. 
+ pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type='PackDetInputs') +] + +train_dataset = dict( + # use MultiImageMixDataset wrapper to support mosaic and mixup + type='MultiImageMixDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=[ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True) + ], + filter_cfg=dict(filter_empty_gt=False, min_size=32), + backend_args=backend_args), + pipeline=train_pipeline) + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=train_dataset) +val_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + backend_args=backend_args) +test_evaluator = val_evaluator + +# training settings +max_epochs = 300 +num_last_epochs = 15 +interval = 10 + +train_cfg = dict(max_epochs=max_epochs, val_interval=interval) + +# optimizer +# default 8 gpu +base_lr = 0.01 +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4, + nesterov=True), + paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) + +# learning rate +param_scheduler = [ + dict( + # use quadratic formula to warm up 5 epochs + # and lr is updated by iteration + # TODO: fix default scope in get function + type='mmdet.QuadraticWarmupLR', + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True), + dict( + # use cosine lr from 5 to 285 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=5, + T_max=max_epochs - num_last_epochs, + end=max_epochs - num_last_epochs, + by_epoch=True, + convert_to_iter_based=True), + dict( + # use fixed lr during last 15 epochs + type='ConstantLR', + by_epoch=True, + factor=1, + begin=max_epochs - num_last_epochs, + end=max_epochs, + ) +] + +default_hooks = dict( + checkpoint=dict( + interval=interval, + max_keep_ckpts=3 # only keep latest 3 checkpoints + )) + +custom_hooks = [ + dict( + type='YOLOXModeSwitchHook', + num_last_epochs=num_last_epochs, + priority=48), + dict(type='SyncNormHook', priority=48), + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0001, + update_buffers=True, + priority=49) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. 
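+# When auto-scaling is enabled, the learning rate is scaled linearly by (actual total batch size / base_batch_size).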
+# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/modules/control/proc/dwpose/draw.py b/modules/control/proc/dwpose/draw.py new file mode 100644 index 000000000..dbccc3af1 --- /dev/null +++ b/modules/control/proc/dwpose/draw.py @@ -0,0 +1,307 @@ +import math +import numpy as np +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, _C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = 
int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + import matplotlib as mpl + + H, W, _C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + # (person_number*2, 21, 2) + for i in range(len(all_hand_peaks)): + peaks = all_hand_peaks[i] + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for _, keyponit in enumerate(peaks): + x, y = keyponit + + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, _C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) 
+ # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: + x = 0 + if y < 0: + y = 0 + width1 = width + width2 = width + if x + width > image_width: + width1 = image_width - x + if y + width > image_height: + width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/modules/control/proc/dwpose/wholebody.py b/modules/control/proc/dwpose/wholebody.py new file mode 100644 index 000000000..356044c35 --- /dev/null +++ b/modules/control/proc/dwpose/wholebody.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
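+# DWPose whole-body wrapper: an mmdet YOLOX detector proposes person boxes and an mmpose
+# DWPose (RTMPose) top-down estimator predicts COCO-WholeBody keypoints for each box;
+# __call__ returns (keypoints, scores) remapped to OpenPose ordering with a synthetic neck joint.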
+import os
+import numpy as np
+from modules.shared import log
+
+mmok = True
+
+try:
+    import mmcv # pylint: disable=unused-import
+except ImportError as e:
+    mmok = False
+    log.error(f"Control processor DWPose: {e}")
+try:
+    from mmpose.apis import inference_topdown
+    from mmpose.apis import init_model as init_pose_estimator
+    from mmpose.evaluation.functional import nms
+    from mmpose.utils import adapt_mmdet_pipeline
+    from mmpose.structures import merge_data_samples
+except ImportError as e:
+    mmok = False
+    log.error(f"Control processor DWPose: {e}")
+
+try:
+    from mmdet.apis import inference_detector, init_detector
+except ImportError as e:
+    mmok = False
+    log.error(f"Control processor DWPose: {e}")
+
+    def inference_detector(*args, **kwargs):
+        return None
+
+if not mmok:
+    log.error('Control processor DWPose: OpenMMLab is not installed')
+
+
+class Wholebody:
+    def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"):
+        if not mmok:
+            self.detector = lambda *args, **kwargs: None
+            return None
+        prefix = os.path.dirname(__file__)
+        if det_config is None:
+            det_config = "config/yolox_l_8xb8-300e_coco.py"
+        if pose_config is None:
+            pose_config = "config/dwpose-l_384x288.py"
+        if not det_config.startswith(prefix):
+            det_config = os.path.join(prefix, det_config)
+        if not pose_config.startswith(prefix):
+            pose_config = os.path.join(prefix, pose_config)
+        if det_ckpt is None:
+            det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
+        if pose_ckpt is None:
+            pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"
+        # build detector
+        self.detector = init_detector(det_config, det_ckpt, device=device)
+        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
+        # build pose estimator
+        self.pose_estimator = init_pose_estimator(
+            pose_config,
+            pose_ckpt,
+            device=device)
+
+    def to(self, device):
+        self.detector.to(device)
+        self.pose_estimator.to(device)
+        return self
+
+    def __call__(self, oriImg):
+        if not mmok:
+            return None, None
+        # predict bbox
+        det_result = inference_detector(self.detector, oriImg)
+        pred_instance = det_result.pred_instances.cpu().numpy()
+        bboxes = np.concatenate((pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
+        bboxes = bboxes[np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.5)]
+        # set NMS threshold
+        bboxes = bboxes[nms(bboxes, 0.7), :4]
+        # predict keypoints
+        if len(bboxes) == 0:
+            pose_results = inference_topdown(self.pose_estimator, oriImg)
+        else:
+            pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes)
+        preds = merge_data_samples(pose_results)
+        preds = preds.pred_instances
+        # preds = pose_results[0].pred_instances
+        keypoints = preds.get('transformed_keypoints', preds.keypoints)
+        if 'keypoint_scores' in preds:
+            scores = preds.keypoint_scores
+        else:
+            scores = np.ones(keypoints.shape[:-1])
+        if 'keypoints_visible' in preds:
+            visible = preds.keypoints_visible
+        else:
+            visible = np.ones(keypoints.shape[:-1])
+        keypoints_info = np.concatenate(
+            (keypoints, scores[..., None], visible[..., None]),
+            axis=-1)
+        # compute neck joint
+        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
+        # neck score when visualizing pred
+        neck[:, 2:4] = np.logical_and(
+            keypoints_info[:, 5, 2:4] > 0.3,
+            keypoints_info[:, 6, 2:4] > 0.3).astype(int)
+        new_keypoints_info = np.insert(
+            keypoints_info, 17, neck, axis=1)
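+        # reorder keypoints from MMPose to OpenPose convention (the inserted neck becomes index 1)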
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3] + openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17] + new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + keypoints, scores, visible = keypoints_info[..., :2], keypoints_info[..., 2], keypoints_info[..., 3] + return keypoints, scores diff --git a/modules/control/proc/edge.py b/modules/control/proc/edge.py new file mode 100644 index 000000000..73481129f --- /dev/null +++ b/modules/control/proc/edge.py @@ -0,0 +1,64 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from modules.control.util import HWC3, resize_image + +ed = None +""" + PFmode: bool + EdgeDetectionOperator: int + GradientThresholdValue: int + AnchorThresholdValue: int + ScanInterval: int + MinPathLength: int + Sigma: float + SumFlag: bool + NFAValidation: bool + MinLineLength: int + MaxDistanceBetweenTwoLines: float + LineFitErrorThreshold: float + MaxErrorThreshold: float +""" + +class EdgeDetector: + def __call__(self, input_image=None, pf=True, mode='edge', detect_resolution=512, image_resolution=512, output_type=None, **kwargs): + global ed # pylint: disable=global-statement + if ed is None: + ed = cv2.ximgproc.createEdgeDrawing() + params = cv2.ximgproc.EdgeDrawing.Params() + params.PFmode = pf + ed.setParams(params) + if "img" in kwargs: + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("img") + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + img_gray = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY) + edges = ed.detectEdges(img_gray) + if mode == 'edge': + edge_map = ed.getEdgeImage(edges) + else: + edge_map = ed.getGradientImage(edges) + edge_map = np.expand_dims(edge_map, axis=2) + edge_map = cv2.cvtColor(edge_map, cv2.COLOR_GRAY2BGR).astype(np.uint8) + edge_map = HWC3(edge_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + edge_map = cv2.resize(edge_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + edge_map = Image.fromarray(edge_map) + edge_map = edge_map.convert('L') + + return edge_map diff --git a/modules/control/proc/hed.py b/modules/control/proc/hed.py new file mode 100644 index 000000000..9504e627e --- /dev/null +++ b/modules/control/proc/hed.py @@ -0,0 +1,128 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. 
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import warnings + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, nms, resize_image, safe_step + + +class DoubleConvBlock(torch.nn.Module): # pylint: disable=abstract-method + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for _i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): # pylint: disable=abstract-method + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + +class HEDdetector: + def __init__(self, netNetwork): + self.netNetwork = netNetwork + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None): + filename = filename or "ControlNetHED.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + + netNetwork = ControlNetHED_Apache2() + netNetwork.load_state_dict(torch.load(model_path, map_location='cpu')) + netNetwork.float().eval() + + return cls(netNetwork) + + def to(self, device): + self.netNetwork.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.netNetwork.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + H, W, _C = input_image.shape + image_hed = torch.from_numpy(input_image.copy()).float().to(device) + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + if safe: + edge = safe_step(edge) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if scribble: + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/leres/__init__.py b/modules/control/proc/leres/__init__.py new file mode 100644 index 000000000..3a62882bd --- /dev/null +++ b/modules/control/proc/leres/__init__.py @@ -0,0 +1,108 @@ +import os + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image +from .leres.depthmap import estimateboost, estimateleres +from .leres.multi_depth_model_woauxi import RelDepthModel +from .leres.net_tools import strip_prefix_if_present +from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel +from .pix2pix.options.test_options import TestOptions + + +class LeresDetector: + def __init__(self, model, pix2pixmodel): + self.model = model + self.pix2pixmodel = pix2pixmodel + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None): + filename = filename or "res101.pth" + pix2pix_filename = pix2pix_filename or "latest_net_G.pth" + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + checkpoint = torch.load(model_path, map_location=torch.device('cpu')) + model = RelDepthModel(backbone='resnext101') + model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True) + del checkpoint + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, pix2pix_filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir) + opt = TestOptions().parse() + if not torch.cuda.is_available(): + opt.gpu_ids = [] # cpu mode + pix2pixmodel = Pix2Pix4DepthModel(opt) + pix2pixmodel.save_dir = os.path.dirname(model_path) + 
pix2pixmodel.load_networks('latest') + pix2pixmodel.eval() + return cls(model, pix2pixmodel) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512, output_type="pil"): + # device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + height, width, _dim = input_image.shape + + if boost: + depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height)) + else: + depth = estimateleres(input_image, self.model, width, height) + + numbytes=2 + depth_min = depth.min() + depth_max = depth.max() + max_val = (2**(8*numbytes))-1 + + # check output before normalizing and mapping to 16 bit + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape) + + # single channel, 16 bit image + depth_image = out.astype("uint16") + + # convert to uint8 + depth_image = cv2.convertScaleAbs(depth_image, alpha=255.0/65535.0) + + # remove near + if thr_a != 0: + thr_a = thr_a/100*255 + depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1] + + # invert image + depth_image = cv2.bitwise_not(depth_image) + + # remove bg + if thr_b != 0: + thr_b = thr_b/100*255 + depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1] + + detected_map = depth_image + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/leres/leres/LICENSE b/modules/control/proc/leres/leres/LICENSE new file mode 100644 index 000000000..e0f1d07d9 --- /dev/null +++ b/modules/control/proc/leres/leres/LICENSE @@ -0,0 +1,23 @@ +https://github.com/thygate/stable-diffusion-webui-depthmap-script + +MIT License + +Copyright (c) 2023 Bob Thiry + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/modules/control/proc/leres/leres/Resnet.py b/modules/control/proc/leres/leres/Resnet.py new file mode 100644 index 000000000..c9041b187 --- /dev/null +++ b/modules/control/proc/leres/leres/Resnet.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn as NN + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + #self.avgpool = nn.AvgPool2d(7, stride=1) + #self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, 
mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + features = [] + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + x = self.layer2(x) + features.append(x) + x = self.layer3(x) + features.append(x) + x = self.layer4(x) + features.append(x) + + return features + + +def resnet18(pretrained=True, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def resnet34(pretrained=True, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + return model + + +def resnet50(pretrained=True, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + return model + + +def resnet101(pretrained=True, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + + return model + + +def resnet152(pretrained=True, **kwargs): + """Constructs a ResNet-152 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + return model diff --git a/modules/control/proc/leres/leres/Resnext_torch.py b/modules/control/proc/leres/leres/Resnext_torch.py new file mode 100644 index 000000000..9af54fcc3 --- /dev/null +++ b/modules/control/proc/leres/leres/Resnext_torch.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# coding: utf-8 +import torch.nn as nn + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + +__all__ = ['resnext101_32x8d'] + + +model_urls = { + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + #self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + features = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + + x = self.layer2(x) + features.append(x) + + x = self.layer3(x) + features.append(x) + + x = self.layer4(x) + features.append(x) + + #x = self.avgpool(x) + #x = torch.flatten(x, 1) + #x = self.fc(x) + + return features + + def forward(self, x): + return self._forward_impl(x) + + + +def resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + return model + diff --git a/modules/control/proc/leres/leres/__init__.py b/modules/control/proc/leres/leres/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/leres/leres/depthmap.py b/modules/control/proc/leres/leres/depthmap.py new file mode 100644 index 000000000..d89ac7e7a --- /dev/null +++ b/modules/control/proc/leres/leres/depthmap.py @@ -0,0 +1,546 @@ +# Author: thygate +# https://github.com/thygate/stable-diffusion-webui-depthmap-script + +import gc +from operator import getitem + +import cv2 +import numpy as np +import skimage.measure +import torch +from torchvision.transforms import transforms + +from modules.control.util import torch_gc + +whole_size_threshold = 1600 # R_max from the paper +pix2pixsize = 1024 + +def scale_torch(img): + """ + Scale the image and output it in torch.tensor. + :param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W] + :param scale: the scale factor. float + :return: img. 
[C, H, W] + """ + if len(img.shape) == 2: + img = img[np.newaxis, :, :] + if img.shape[2] == 3: + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )]) + img = transform(img.astype(np.float32)) + else: + img = img.astype(np.float32) + img = torch.from_numpy(img) + return img + +def estimateleres(img, model, w, h): + device = next(iter(model.parameters())).device + # leres transform input + rgb_c = img[:, :, ::-1].copy() + A_resize = cv2.resize(rgb_c, (w, h)) + img_torch = scale_torch(A_resize)[None, :, :, :] + + # compute + img_torch = img_torch.to(device) + prediction = model.depth_model(img_torch) + + prediction = prediction.squeeze().cpu().numpy() + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + return prediction + +def generatemask(size): + # Generates a Guassian mask + mask = np.zeros(size, dtype=np.float32) + sigma = int(size[0]/16) + k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1) + mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1 + mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma) + mask = (mask - mask.min()) / (mask.max() - mask.min()) + mask = mask.astype(np.float32) + return mask + +def resizewithpool(img, size): + i_size = img.shape[0] + n = int(np.floor(i_size/size)) + + out = skimage.measure.block_reduce(img, (n, n), np.max) + return out + +def rgb2gray(rgb): + # Converts rgb to gray + return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140]) + +def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000): + # Returns the R_x resolution described in section 5 of the main paper. + + # Parameters: + # img :input rgb image + # basesize : size the dilation kernel which is equal to receptive field of the network. + # confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue. + # scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3. + # whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper) + + # Returns: + # outputsize_scale*speed_scale :The computed R_x resolution + # patch_scale: K parameter from section 6 of the paper + + # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search + speed_scale = 32 + image_dim = int(min(img.shape[0:2])) + + gray = rgb2gray(img) + grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)) + grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA) + + # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues + m = grad.min() + M = grad.max() + middle = m + (0.4 * (M - m)) + grad[grad < middle] = 0 + grad[grad >= middle] = 1 + + # dilation kernel with size of the receptive field + kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float) + # dilation kernel with size of the a quarter of receptive field used to compute k + # as described in section 6 of main paper + kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float) + + # Output resolution limit set by the whole_size_threshold and scale_threshold. 
+ threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2])) + + outputsize_scale = basesize / speed_scale + for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))): + grad_resized = resizewithpool(grad, p_size) + grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST) + grad_resized[grad_resized >= 0.5] = 1 + grad_resized[grad_resized < 0.5] = 0 + + dilated = cv2.dilate(grad_resized, kernel, iterations=1) + meanvalue = (1-dilated).mean() + if meanvalue > confidence: + break + else: + outputsize_scale = p_size + + grad_region = cv2.dilate(grad_resized, kernel2, iterations=1) + patch_scale = grad_region.mean() + + return int(outputsize_scale*speed_scale), patch_scale + +# Generate a double-input depth estimation +def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel): + # Generate the low resolution estimation + estimate1 = singleestimate(img, size1, model, net_type) + # Resize to the inference size of merge network. + estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Generate the high resolution estimation + estimate2 = singleestimate(img, size2, model, net_type) + # Resize to the inference size of merge network. + estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Inference on the merge model + pix2pixmodel.set_input(estimate1, estimate2) + pix2pixmodel.test() + visuals = pix2pixmodel.get_current_visuals() + prediction_mapped = visuals['fake_B'] + prediction_mapped = (prediction_mapped+1)/2 + prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / ( + torch.max(prediction_mapped) - torch.min(prediction_mapped)) + prediction_mapped = prediction_mapped.squeeze().cpu().numpy() + + return prediction_mapped + +# Generate a single-input depth estimation +def singleestimate(img, msize, model, net_type): + # if net_type == 0: + return estimateleres(img, model, msize, msize) + # else: + # return estimatemidasBoost(img, model, msize, msize) + +def applyGridpatch(blsize, stride, img, box): + # Extract a simple grid patch. + counter1 = 0 + patch_bound_list = {} + for k in range(blsize, img.shape[1] - blsize, stride): + for j in range(blsize, img.shape[0] - blsize, stride): + patch_bound_list[str(counter1)] = {} + patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize] + patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1], + patchbounds[2] - patchbounds[0]] + patch_bound_list[str(counter1)]['rect'] = patch_bound + patch_bound_list[str(counter1)]['size'] = patch_bound[2] + counter1 = counter1 + 1 + return patch_bound_list + +# Generating local patches to perform the local refinement described in section 6 of the main paper. +def generatepatchs(img, base_size): + + # Compute the gradients as a proxy of the contextual cues. + img_gray = rgb2gray(img) + whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\ + np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3)) + + threshold = whole_grad[whole_grad > 0].mean() + whole_grad[whole_grad < threshold] = 0 + + # We use the integral image to speed-up the evaluation of the amount of gradients for each patch. 
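+    # gf: mean gradient magnitude over the whole image; a patch must reach this density to be selected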
+ gf = whole_grad.sum()/len(whole_grad.reshape(-1)) + grad_integral_image = cv2.integral(whole_grad) + + # Variables are selected such that the initial patch size would be the receptive field size + # and the stride is set to 1/3 of the receptive field size. + blsize = int(round(base_size/2)) + stride = int(round(blsize*0.75)) + + # Get initial Grid + patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0]) + + # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine + # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map. + patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf) + + # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest + # patch + patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True) + return patchset + +def getGF_fromintegral(integralimage, rect): + # Computes the gradient density of a given patch from the gradient integral image. + x1 = rect[1] + x2 = rect[1]+rect[3] + y1 = rect[0] + y2 = rect[0]+rect[2] + value = integralimage[x2, y2]-integralimage[x1, y2]-integralimage[x2, y1]+integralimage[x1, y1] + return value + +# Adaptively select patches +def adaptiveselection(integral_grad, patch_bound_list, gf): + patchlist = {} + count = 0 + height, width = integral_grad.shape + + search_step = int(32/factor) + + # Go through all patches + for c in range(len(patch_bound_list)): + # Get patch + bbox = patch_bound_list[str(c)]['rect'] + + # Compute the amount of gradients present in the patch from the integral image. + cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3]) + + # Check if patching is beneficial by comparing the gradient density of the patch to + # the gradient density of the whole image + if cgf >= gf: + bbox_test = bbox.copy() + patchlist[str(count)] = {} + + # Enlarge each patch until the gradient density of the patch is equal + # to the whole image gradient density + while True: + + bbox_test[0] = bbox_test[0] - int(search_step/2) + bbox_test[1] = bbox_test[1] - int(search_step/2) + + bbox_test[2] = bbox_test[2] + search_step + bbox_test[3] = bbox_test[3] + search_step + + # Check if we are still within the image + if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \ + or bbox_test[0] + bbox_test[2] >= width: + break + + # Compare gradient density + cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3]) + if cgf < gf: + break + bbox = bbox_test.copy() + + # Add patch to selected patches + patchlist[str(count)]['rect'] = bbox + patchlist[str(count)]['size'] = bbox[2] + count = count + 1 + + # Return selected patches + return patchlist + +def impatch(image, rect): + # Extract the given patch pixels from a given image. 
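+    # rect is [x, y, width, height] in image coordinates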
+ w1 = rect[0] + h1 = rect[1] + w2 = w1 + rect[2] + h2 = h1 + rect[3] + image_patch = image[h1:h2, w1:w2] + return image_patch + +class ImageandPatchs: + def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1): + self.root_dir = root_dir + self.patchsinfo = patchsinfo + self.name = name + self.patchs = patchsinfo + self.scale = scale + + self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)), + interpolation=cv2.INTER_CUBIC) + + self.do_have_estimate = False + self.estimation_updated_image = None + self.estimation_base_image = None + + def __len__(self): + return len(self.patchs) + + def set_base_estimate(self, est): + self.estimation_base_image = est + if self.estimation_updated_image is not None: + self.do_have_estimate = True + + def set_updated_estimate(self, est): + self.estimation_updated_image = est + if self.estimation_base_image is not None: + self.do_have_estimate = True + + def __getitem__(self, index): + patch_id = int(self.patchs[index][0]) + rect = np.array(self.patchs[index][1]['rect']) + msize = self.patchs[index][1]['size'] + + ## applying scale to rect: + rect = np.round(rect * self.scale) + rect = rect.astype('int') + msize = round(msize * self.scale) + + patch_rgb = impatch(self.rgb_image, rect) + if self.do_have_estimate: + patch_whole_estimate_base = impatch(self.estimation_base_image, rect) + patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect) + return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base, + 'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect, + 'size': msize, 'id': patch_id} + else: + return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id} + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). 
+ It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + """ + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + """ + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + #self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + #if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt + + +def estimateboost(img, model, model_type, pix2pixmodel, max_res=512, depthmap_script_boost_rmax=None): + global whole_size_threshold + + # get settings + if depthmap_script_boost_rmax: + whole_size_threshold = depthmap_script_boost_rmax + + if model_type == 0: #leres + net_receptive_field_size = 448 + patch_netsize = 2 * net_receptive_field_size + elif model_type == 1: #dpt_beit_large_512 + net_receptive_field_size = 512 + patch_netsize = 2 * net_receptive_field_size + else: #other midas + net_receptive_field_size = 384 + patch_netsize = 2 * net_receptive_field_size + + gc.collect() + torch_gc() + + # Generate mask used to smoothly blend the local pathc estimations to the base estimate. + # It is arbitrarily large to avoid artifacts during rescaling for each crop. + mask_org = generatemask((3000, 3000)) + mask = mask_org.copy() + + # Value x of R_x defined in the section 5 of the main paper. + r_threshold_value = 0.2 + #if R0: + # r_threshold_value = 0 + + input_resolution = img.shape + scale_threshold = 3 # Allows up-scaling with a scale up to 3 + + # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the + # supplementary material. + whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold) + + # print('wholeImage being processed in :', whole_image_optimal_size) + + # Generate the base estimate using the double estimation. + whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel) + + # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select + # small high-density regions of the image. + global factor + factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2) + # print('Adjust factor is:', 1/factor) + + # Check if Local boosting is beneficial. + if max_res < whole_image_optimal_size: + # print("No Local boosting. 
Specified Max Res is smaller than R20, Returning doubleestimate result") + return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) + + # Compute the default target resolution. + if img.shape[0] > img.shape[1]: + a = 2 * whole_image_optimal_size + b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0]) + else: + a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1]) + b = 2 * whole_image_optimal_size + b = int(round(b / factor)) + a = int(round(a / factor)) + + """ + # recompute a, b and saturate to max res. + if max(a,b) > max_res: + print('Default Res is higher than max-res: Reducing final resolution') + if img.shape[0] > img.shape[1]: + a = max_res + b = round(max_res * img.shape[1] / img.shape[0]) + else: + a = round(max_res * img.shape[0] / img.shape[1]) + b = max_res + b = int(b) + a = int(a) + """ + + img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC) + + # Extract selected patches for local refinement + base_size = net_receptive_field_size * 2 + patchset = generatepatchs(img, base_size) + + # print('Target resolution: ', img.shape) + + # Computing a scale in case user prompted to generate the results as the same resolution of the input. + # Notice that our method output resolution is independent of the input resolution and this parameter will only + # enable a scaling operation during the local patch merge implementation to generate results with the same resolution + # as the input. + """ + if output_resolution == 1: + mergein_scale = input_resolution[0] / img.shape[0] + print('Dynamicly change merged-in resolution; scale:', mergein_scale) + else: + mergein_scale = 1 + """ + # always rescale to input res for now + mergein_scale = input_resolution[0] / img.shape[0] + + imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale) + whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale), + round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC) + imageandpatchs.set_base_estimate(whole_estimate_resized.copy()) + imageandpatchs.set_updated_estimate(whole_estimate_resized.copy()) + + print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2]) + print('Patches to process: '+str(len(imageandpatchs))) + + # Enumerate through all patches, generate their estimations and refining the base estimate. + for patch_ind in range(len(imageandpatchs)): + + # Get patch information + patch = imageandpatchs[patch_ind] # patch object + patch_rgb = patch['patch_rgb'] # rgb patch + patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base + rect = patch['rect'] # patch size and location + patch['id'] # patch ID + org_size = patch_whole_estimate_base.shape # the original size from the unscaled input + print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect) + + # We apply double estimation for patches. The high resolution value is fixed to twice the receptive + # field size of the network for patches to accelerate the process. 
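+        # (doubleestimate, defined earlier in this file, estimates the patch at both of these resolutions and fuses the two results via the pix2pix merge model.)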
+ patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel) + patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Merging the patch estimation into the base estimate using our merge network: + # We feed the patch estimation and the same region from the updated base estimate to the merge network + # to generate the target estimate for the corresponding region. + pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation) + + # Run merging network + pix2pixmodel.test() + visuals = pix2pixmodel.get_current_visuals() + + prediction_mapped = visuals['fake_B'] + prediction_mapped = (prediction_mapped+1)/2 + prediction_mapped = prediction_mapped.squeeze().cpu().numpy() + + mapped = prediction_mapped + + # We use a simple linear polynomial to make sure the result of the merge network would match the values of + # base estimate + p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1) + merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape) + + merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC) + + # Get patch size and location + w1 = rect[0] + h1 = rect[1] + w2 = w1 + rect[2] + h2 = h1 + rect[3] + + # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size + # and resize it to our needed size while merging the patches. + if mask.shape != org_size: + mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR) + + tobemergedto = imageandpatchs.estimation_updated_image + + # Update the whole estimation: + # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless + # blending at the boundaries of the patch region. + tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask) + imageandpatchs.set_updated_estimate(tobemergedto) + + # output + return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) diff --git a/modules/control/proc/leres/leres/multi_depth_model_woauxi.py b/modules/control/proc/leres/leres/multi_depth_model_woauxi.py new file mode 100644 index 000000000..c1266bef1 --- /dev/null +++ b/modules/control/proc/leres/leres/multi_depth_model_woauxi.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + +from . import network_auxi as network +from .net_tools import get_func + + +class RelDepthModel(nn.Module): + def __init__(self, backbone='resnet50'): + super(RelDepthModel, self).__init__() + if backbone == 'resnet50': + encoder = 'resnet50_stride32' + elif backbone == 'resnext101': + encoder = 'resnext101_stride32x8d' + self.depth_model = DepthModel(encoder) + + def inference(self, rgb): + input = rgb.to(self.depth_model.device) + depth = self.depth_model(input) + #pred_depth_out = depth - depth.min() + 0.01 + return depth #pred_depth_out + + +class DepthModel(nn.Module): + def __init__(self, encoder): + super(DepthModel, self).__init__() + backbone = network.__name__.split('.')[-1] + '.' 
+ encoder + self.encoder_modules = get_func(backbone)() + self.decoder_modules = network.Decoder() + + def forward(self, x): + lateral_out = self.encoder_modules(x) + out_logit = self.decoder_modules(lateral_out) + return out_logit diff --git a/modules/control/proc/leres/leres/net_tools.py b/modules/control/proc/leres/leres/net_tools.py new file mode 100644 index 000000000..8d2340000 --- /dev/null +++ b/modules/control/proc/leres/leres/net_tools.py @@ -0,0 +1,54 @@ +import os +from collections import OrderedDict +import importlib +import torch + + +def get_func(func_name): + """Helper to return a function object by name. func_name must identify a + function in this module or the path to a function relative to the base + 'modeling' module. + """ + if func_name == '': + return None + try: + parts = func_name.split('.') + # Refers to a function in this module + if len(parts) == 1: + return globals()[parts[0]] + # Otherwise, assume we're referencing a module under modeling + module_name = 'modules.control.proc.leres.leres.' + '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + return getattr(module, parts[-1]) + except Exception: + print('Failed to find function: %s', func_name) + raise + +def load_ckpt(args, depth_model, shift_model, focal_model): + """ + Load checkpoint. + """ + if os.path.isfile(args.load_ckpt): + print("loading checkpoint %s" % args.load_ckpt) + checkpoint = torch.load(args.load_ckpt) + if shift_model is not None: + shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'), + strict=True) + if focal_model is not None: + focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'), + strict=True) + depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), + strict=True) + del checkpoint + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def strip_prefix_if_present(state_dict, prefix): + keys = sorted(state_dict.keys()) + if not all(key.startswith(prefix) for key in keys): + return state_dict + stripped_state_dict = OrderedDict() + for key, value in state_dict.items(): + stripped_state_dict[key.replace(prefix, "")] = value + return stripped_state_dict diff --git a/modules/control/proc/leres/leres/network_auxi.py b/modules/control/proc/leres/leres/network_auxi.py new file mode 100644 index 000000000..44d96423e --- /dev/null +++ b/modules/control/proc/leres/leres/network_auxi.py @@ -0,0 +1,419 @@ +import torch +import torch.nn as nn +import torch.nn.init as init + +from . 
import Resnet, Resnext_torch + + +def resnet50_stride32(): + return DepthNet(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2]) + +def resnext101_stride32x8d(): + return DepthNet(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2]) + + +class Decoder(nn.Module): + def __init__(self): + super(Decoder, self).__init__() + self.inchannels = [256, 512, 1024, 2048] + self.midchannels = [256, 256, 256, 512] + self.upfactors = [2,2,2,2] + self.outchannels = 1 + + self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3]) + self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True) + self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True) + + self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2]) + self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1]) + self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0]) + + self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2) + self._init_params() + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): #NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + def forward(self, features): + x_32x = self.conv(features[3]) # 1/32 + x_32 = self.conv1(x_32x) + x_16 = self.upsample(x_32) # 1/16 + + x_8 = self.ffm2(features[2], x_16) # 1/8 + x_4 = self.ffm1(features[1], x_8) # 1/4 + x_2 = self.ffm0(features[0], x_4) # 1/2 + #----------------------------------------- + x = self.outconv(x_2) # original size + return x + +class DepthNet(nn.Module): + __factory = { + 18: Resnet.resnet18, + 34: Resnet.resnet34, + 50: Resnet.resnet50, + 101: Resnet.resnet101, + 152: Resnet.resnet152 + } + def __init__(self, + backbone='resnet', + depth=50, + upfactors=None): + if upfactors is None: + upfactors = [2, 2, 2, 2] + super(DepthNet, self).__init__() + self.backbone = backbone + self.depth = depth + self.pretrained = False + self.inchannels = [256, 512, 1024, 2048] + self.midchannels = [256, 256, 256, 512] + self.upfactors = upfactors + self.outchannels = 1 + + # Build model + if self.backbone == 'resnet': + if self.depth not in DepthNet.__factory: + raise KeyError("Unsupported depth:", self.depth) + self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained) + elif self.backbone == 'resnext101_32x8d': + self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained) + else: + self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained) + + def forward(self, x): + x = self.encoder(x) # 1/32, 1/16, 1/8, 1/4 + return x + + +class FTB(nn.Module): + def __init__(self, inchannels, midchannels=512): + super(FTB, self).__init__() + self.in1 = inchannels + self.mid = midchannels + self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, + bias=True) + 
# NN.BatchNorm2d + self.conv_branch = nn.Sequential(nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, + padding=1, stride=1, bias=True), \ + nn.BatchNorm2d(num_features=self.mid), \ + nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, + padding=1, stride=1, bias=True)) + self.relu = nn.ReLU(inplace=True) + + self.init_params() + + def forward(self, x): + x = self.conv1(x) + x = x + self.conv_branch(x) + x = self.relu(x) + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class ATA(nn.Module): + def __init__(self, inchannels, reduction=8): + super(ATA, self).__init__() + self.inchannels = inchannels + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction), + nn.ReLU(inplace=True), + nn.Linear(self.inchannels // reduction, self.inchannels), + nn.Sigmoid()) + self.init_params() + + def forward(self, low_x, high_x): + n, c, _, _ = low_x.size() + x = torch.cat([low_x, high_x], 1) + x = self.avg_pool(x) + x = x.view(n, -1) + x = self.fc(x).view(n, c, 1, 1) + x = low_x * x + high_x + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + # init.normal(m.weight, std=0.01) + init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + # init.normal_(m.weight, std=0.01) + init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class FFM(nn.Module): + def __init__(self, inchannels, midchannels, outchannels, upfactor=2): + super(FFM, self).__init__() + self.inchannels = inchannels + self.midchannels = midchannels + self.outchannels = outchannels + self.upfactor = upfactor + + self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels) + # self.ata = ATA(inchannels = self.midchannels) + self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) + + self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) + + self.init_params() + + def forward(self, low_x, high_x): + x = self.ftb1(low_x) + x = x + high_x + x = self.ftb2(x) + x = self.upsample(x) + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, 
std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class AO(nn.Module): + # Adaptive output module + def __init__(self, inchannels, outchannels, upfactor=2): + super(AO, self).__init__() + self.inchannels = inchannels + self.outchannels = outchannels + self.upfactor = upfactor + + self.adapt_conv = nn.Sequential( + nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1, + stride=1, bias=True), \ + nn.BatchNorm2d(num_features=self.inchannels // 2), \ + nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1, + stride=1, bias=True), \ + nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)) + + self.init_params() + + def forward(self, x): + x = self.adapt_conv(x) + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + + +# ============================================================================================================== + + +class ResidualConv(nn.Module): + def __init__(self, inchannels): + super(ResidualConv, self).__init__() + # NN.BatchNorm2d + self.conv = nn.Sequential( + # nn.BatchNorm2d(num_features=inchannels), + nn.ReLU(inplace=False), + # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True), + # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True) + nn.Conv2d(in_channels=inchannels, out_channels=inchannels / 2, kernel_size=3, padding=1, stride=1, + bias=False), + nn.BatchNorm2d(num_features=inchannels / 2), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=inchannels / 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, + bias=False) + ) + self.init_params() + + def forward(self, x): + x = self.conv(x) + x + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class FeatureFusion(nn.Module): + def 
__init__(self, inchannels, outchannels): + super(FeatureFusion, self).__init__() + self.conv = ResidualConv(inchannels=inchannels) + # NN.BatchNorm2d + self.up = nn.Sequential(ResidualConv(inchannels=inchannels), + nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3, + stride=2, padding=1, output_padding=1), + nn.BatchNorm2d(num_features=outchannels), + nn.ReLU(inplace=True)) + + def forward(self, lowfeat, highfeat): + return self.up(highfeat + self.conv(lowfeat)) + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class SenceUnderstand(nn.Module): + def __init__(self, channels): + super(SenceUnderstand, self).__init__() + self.channels = channels + self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1), + nn.ReLU(inplace=True)) + self.pool = nn.AdaptiveAvgPool2d(8) + self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels), + nn.ReLU(inplace=True)) + self.conv2 = nn.Sequential( + nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0), + nn.ReLU(inplace=True)) + self.initial_params() + + def forward(self, x): + n, c, h, w = x.size() + x = self.conv1(x) + x = self.pool(x) + x = x.view(n, -1) + x = self.fc(x) + x = x.view(n, self.channels, 1, 1) + x = self.conv2(x) + x = x.repeat(1, 1, h, w) + return x + + def initial_params(self, dev=0.01): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # print torch.sum(m.weight) + m.weight.data.normal_(0, dev) + if m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, nn.ConvTranspose2d): + # print torch.sum(m.weight) + m.weight.data.normal_(0, dev) + if m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, dev) + + +if __name__ == '__main__': + net = DepthNet(depth=50, pretrained=True) + print(net) + inputs = torch.ones(4,3,128,128) + out = net(inputs) + print(out.size()) + diff --git a/modules/control/proc/leres/pix2pix/LICENSE b/modules/control/proc/leres/pix2pix/LICENSE new file mode 100644 index 000000000..38b1a24fd --- /dev/null +++ b/modules/control/proc/leres/pix2pix/LICENSE @@ -0,0 +1,19 @@ +https://github.com/compphoto/BoostingMonocularDepth + +Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved. + +This software is for academic use only. A redistribution of this +software, with or without modifications, has to be for academic +use only, while giving the appropriate credit to the original +authors of the software. The methods implemented as a part of +this software may be covered under patents or patent applications. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/modules/control/proc/leres/pix2pix/__init__.py b/modules/control/proc/leres/pix2pix/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/modules/control/proc/leres/pix2pix/models/__init__.py b/modules/control/proc/leres/pix2pix/models/__init__.py
new file mode 100644
index 000000000..936c0e7d3
--- /dev/null
+++ b/modules/control/proc/leres/pix2pix/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+    -- <set_input>: unpack data from dataset and apply preprocessing.
+    -- <forward>: produce intermediate results.
+    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+    -- self.loss_names (str list): specify the training losses that you want to plot and save.
+    -- self.model_names (str list): define networks used in our training.
+    -- self.visual_names (str list): specify the images that you want to display and save.
+    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from .base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Import the module "models/[model_name]_model.py".
+
+    In the file, the class called DatasetNameModel() will
+    be instantiated. It has to be a subclass of BaseModel,
+    and it is case-insensitive.
+    """
+    model_filename = "modules.control.proc.leres.pix2pix.models." + model_name + "_model"
+    modellib = importlib.import_module(model_filename)
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+           and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+        exit(0)
+
+    return model
+
+
+def get_option_setter(model_name):
+    """Return the static method <modify_commandline_options> of the model class."""
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/modules/control/proc/leres/pix2pix/models/base_model.py b/modules/control/proc/leres/pix2pix/models/base_model.py new file mode 100644 index 000000000..64af9aff6 --- /dev/null +++ b/modules/control/proc/leres/pix2pix/models/base_model.py @@ -0,0 +1,242 @@ +import gc +import os +from abc import ABC, abstractmethod +from collections import OrderedDict + +import torch + +from modules.control.util import torch_gc +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this function, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. 
+ """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch + self.load_networks(load_suffix) + self.print_networks(opt.verbose) + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + + def test(self): + """Forward function used in test time. + + It also calls to produce additional visualization results + """ + self.forward() + self.compute_visuals() + + def compute_visuals(self): # noqa + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + old_lr = self.optimizers[0].param_groups[0]['lr'] + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate %.7f -> %.7f' % (old_lr, lr)) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + torch.save(net.module.cpu().state_dict(), save_path) + net.cuda(self.gpu_ids[0]) + else: + torch.save(net.cpu().state_dict(), save_path) + + def unload_network(self, name): + """Unload network and gc. 
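+        Note: only the local reference is deleted here; the 'net' + name attribute still holds the module, so callers should clear their own reference as well before the collection can free memory.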
+ """ + if isinstance(name, str): + net = getattr(self, 'net' + name) + del net + gc.collect() + torch_gc() + return None + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + load_filename = '%s_net_%s.pth' % (epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + # print('Loading depth boost model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + # patch InstanceNorm checkpoints prior to 0.4 + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + net.load_state_dict(state_dict) + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/modules/control/proc/leres/pix2pix/models/base_model_hg.py b/modules/control/proc/leres/pix2pix/models/base_model_hg.py new file mode 100644 index 000000000..1709accdf --- /dev/null +++ b/modules/control/proc/leres/pix2pix/models/base_model_hg.py @@ -0,0 +1,58 @@ +import os +import torch + +class BaseModelHG(): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, 
no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = '_%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda(device_id=gpu_ids[0]) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print(save_path) + model = torch.load(save_path) + return model + # network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass diff --git a/modules/control/proc/leres/pix2pix/models/networks.py b/modules/control/proc/leres/pix2pix/models/networks.py new file mode 100644 index 000000000..1f076f89f --- /dev/null +++ b/modules/control/proc/leres/pix2pix/models/networks.py @@ -0,0 +1,628 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler + + +############################################################################### +# Helper Functions +############################################################################### + + +class Identity(nn.Module): + def forward(self, x): + return x + + +def get_norm_layer(norm_type='instance'): + """Return a normalization layer + + Parameters: + norm_type (str) -- the name of the normalization layer: batch | instance | none + + For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). + For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. + """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + def norm_layer(x): return Identity() + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. 
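+    (Concretely, the 'linear' rule below keeps the multiplier at 1.0 while epoch + opt.epoch_count <= opt.n_epochs, then decays it linearly over the next opt.n_epochs_decay epochs;
+    for example, with epoch_count=1, n_epochs=100, n_epochs_decay=100 the multiplier stays 1.0 through epoch 99 and falls to roughly 0.01 by epoch 199.)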
+ """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + # print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. 
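+    Note: when gpu_ids is non-empty the returned network is wrapped in torch.nn.DataParallel, so its parameters are accessed through `.module`.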
+ """ + if gpu_ids is None: + gpu_ids = [] + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=None): + """Create a generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 + norm (str) -- the name of normalization layers used in the network: batch | instance | none + use_dropout (bool) -- if use dropout layers. + init_type (str) -- the name of our initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Returns a generator + + Our current implementation provides two types of generators: + U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) + The original U-Net paper: https://arxiv.org/abs/1505.04597 + + Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) + Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. + We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). + + + The generator has been initialized by . It uses RELU for non-linearity. + """ + if gpu_ids is None: + gpu_ids = [] + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netG == 'resnet_9blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) + elif netG == 'resnet_6blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) + elif netG == 'resnet_12blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=12) + elif netG == 'unet_128': + net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_256': + net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_672': + net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_960': + net = UnetGenerator(input_nc, output_nc, 6, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_1024': + net = UnetGenerator(input_nc, output_nc, 10, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % netG) + return init_net(net, init_type, init_gain, gpu_ids) + + +def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=None): + """Create a discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the first conv layer + netD (str) -- the architecture's name: basic | n_layers | pixel + n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' + norm (str) -- the type of normalization layers used in the network. 
+ init_type (str) -- the name of the initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Returns a discriminator + + Our current implementation provides three types of discriminators: + [basic]: 'PatchGAN' classifier described in the original pix2pix paper. + It can classify whether 70x70 overlapping patches are real or fake. + Such a patch-level discriminator architecture has fewer parameters + than a full-image discriminator and can work on arbitrarily-sized images + in a fully convolutional fashion. + + [n_layers]: With this mode, you can specify the number of conv layers in the discriminator + with the parameter (default=3 as used in [basic] (PatchGAN).) + + [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. + It encourages greater color diversity but has no effect on spatial statistics. + + The discriminator has been initialized by . It uses Leakly RELU for non-linearity. + """ + if gpu_ids is None: + gpu_ids = [] + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netD == 'basic': # default PatchGAN classifier + net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) + elif netD == 'n_layers': # more options + net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) + elif netD == 'pixel': # classify if each pixel is real or fake + net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) + return init_net(net, init_type, init_gain, gpu_ids) + + +############################################################################## +# Classes +############################################################################## +class GANLoss(nn.Module): + """Define different GAN objectives. + + The GANLoss class abstracts away the need to create the target label tensor + that has the same size as the input. + """ + + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): + """ Initialize the GANLoss class. + + Parameters: + gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. + target_real_label (bool) - - label for a real image + target_fake_label (bool) - - label of a fake image + + Note: Do not use sigmoid as the last layer of Discriminator. + LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. + """ + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.gan_mode = gan_mode + if gan_mode == 'lsgan': + self.loss = nn.MSELoss() + elif gan_mode == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif gan_mode in ['wgangp']: + self.loss = None + else: + raise NotImplementedError('gan mode %s not implemented' % gan_mode) + + def get_target_tensor(self, prediction, target_is_real): + """Create label tensors with the same size as the input. 
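+        (__call__ below uses this for the 'lsgan' and 'vanilla' modes; 'wgangp' derives its loss directly from the prediction mean.)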
+ + Parameters: + prediction (tensor) - - tpyically the prediction from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + + Returns: + A label tensor filled with ground truth label, and with the size of the input + """ + + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(prediction) + + def __call__(self, prediction, target_is_real): + """Calculate loss given Discriminator's output and grount truth labels. + + Parameters: + prediction (tensor) - - tpyically the prediction output from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + + Returns: + the calculated loss. + """ + if self.gan_mode in ['lsgan', 'vanilla']: + target_tensor = self.get_target_tensor(prediction, target_is_real) + loss = self.loss(prediction, target_tensor) + elif self.gan_mode == 'wgangp': + if target_is_real: + loss = -prediction.mean() + else: + loss = prediction.mean() + return loss + + +def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): + """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 + + Arguments: + netD (network) -- discriminator network + real_data (tensor array) -- real images + fake_data (tensor array) -- generated images from the generator + device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + type (str) -- if we mix real and fake data or not [real | fake | mixed]. + constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 + lambda_gp (float) -- weight for this loss + + Returns the gradient penalty loss + """ + if lambda_gp > 0.0: + if type == 'real': # either use real images, fake images, or a linear interpolation of two. + interpolatesv = real_data + elif type == 'fake': + interpolatesv = fake_data + elif type == 'mixed': + alpha = torch.rand(real_data.shape[0], 1, device=device) + alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) + interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) + else: + raise NotImplementedError('{} not implemented'.format(type)) + interpolatesv.requires_grad_(True) + disc_interpolates = netD(interpolatesv) + gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True) + gradients = gradients[0].view(real_data.size(0), -1) # flat the data + gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps + return gradient_penalty, gradients + else: + return 0.0, None + + +class ResnetGenerator(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
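+    (Structure: a 7x7 input conv, two stride-2 downsampling convs, n_blocks residual blocks, two transposed-conv upsampling layers, and a final 7x7 conv with Tanh.)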
+ + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + for _i in range(n_blocks): # add ResNet blocks + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Initialize the Resnet block + + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Construct a convolutional block. + + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. 
+ use_bias (bool) -- if the conv layer uses bias or not + + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for _i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. 
+ + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + +class 
PixelDiscriminator(nn.Module): + """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" + + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): + """Construct a 1x1 PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + """ + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + """Standard forward.""" + return self.net(input) diff --git a/modules/control/proc/leres/pix2pix/models/pix2pix4depth_model.py b/modules/control/proc/leres/pix2pix/models/pix2pix4depth_model.py new file mode 100644 index 000000000..aac9ae83a --- /dev/null +++ b/modules/control/proc/leres/pix2pix/models/pix2pix4depth_model.py @@ -0,0 +1,155 @@ +import torch +from .base_model import BaseModel +from . import networks + + +class Pix2Pix4DepthModel(BaseModel): + """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. + + The model training requires '--dataset_mode aligned' dataset. + By default, it uses a '--netG unet256' U-Net generator, + a '--netD basic' discriminator (PatchGAN), + and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper). + + pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf + """ + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + + For pix2pix, we do not use image buffer + The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 + By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. + """ + # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) + parser.set_defaults(input_nc=2,output_nc=1,norm='none', netG='unet_1024', dataset_mode='depthmerge') + if is_train: + parser.set_defaults(pool_size=0, gan_mode='vanilla',) + parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss') + return parser + + def __init__(self, opt): + """Initialize the pix2pix class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> + + self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] + # self.loss_names = ['G_L1'] + + # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> + if self.isTrain: + self.visual_names = ['outer','inner', 'fake_B', 'real_B'] + else: + self.visual_names = ['fake_B'] + + # specify the models you want to save to the disk.
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> + if self.isTrain: + self.model_names = ['G','D'] + else: # during test time, only load G + self.model_names = ['G'] + + # define networks (both generator and discriminator) + self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none', + False, 'normal', 0.02, self.gpu_ids) + + if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc + self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + # define loss functions + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + self.optimizers.append(self.optimizer_D) + + def set_input_train(self, input): + self.outer = input['data_outer'].to(self.device) + self.outer = torch.nn.functional.interpolate(self.outer,(1024,1024),mode='bilinear',align_corners=False) + + self.inner = input['data_inner'].to(self.device) + self.inner = torch.nn.functional.interpolate(self.inner,(1024,1024),mode='bilinear',align_corners=False) + + self.image_paths = input['image_path'] + + if self.isTrain: + self.gtfake = input['data_gtfake'].to(self.device) + self.gtfake = torch.nn.functional.interpolate(self.gtfake, (1024, 1024), mode='bilinear', align_corners=False) + self.real_B = self.gtfake + + self.real_A = torch.cat((self.outer, self.inner), 1) + + def set_input(self, outer, inner): + inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0) + outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0) + + inner = (inner - torch.min(inner))/(torch.max(inner)-torch.min(inner)) + outer = (outer - torch.min(outer))/(torch.max(outer)-torch.min(outer)) + + inner = self.normalize(inner) + outer = self.normalize(outer) + + self.real_A = torch.cat((outer, inner), 1).to(self.device) + + + def normalize(self, input): + input = input * 2 + input = input - 1 + return input + + def forward(self): + """Run forward pass; called by both functions <optimize_parameters> and <test>.""" + self.fake_B = self.netG(self.real_A) # G(A) + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake; stop backprop to the generator by detaching fake_B + fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Real + real_AB = torch.cat((self.real_A, self.real_B), 1) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + self.loss_D.backward() + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + # First, G(A) should fake the discriminator + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + pred_fake = self.netD(fake_AB) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + # Second, G(A) = B + self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 + # combine loss and calculate gradients +
self.loss_G = self.loss_G_L1 + self.loss_G_GAN + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() # compute fake images: G(A) + # update D + self.set_requires_grad(self.netD, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + # update G + self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G + self.optimizer_G.zero_grad() # set G's gradients to zero + self.backward_G() # calculate gradients for G + self.optimizer_G.step() # update G's weights diff --git a/modules/control/proc/leres/pix2pix/options/__init__.py b/modules/control/proc/leres/pix2pix/options/__init__.py new file mode 100644 index 000000000..e7eedebe5 --- /dev/null +++ b/modules/control/proc/leres/pix2pix/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/modules/control/proc/leres/pix2pix/options/base_options.py b/modules/control/proc/leres/pix2pix/options/base_options.py new file mode 100644 index 000000000..533a1e88a --- /dev/null +++ b/modules/control/proc/leres/pix2pix/options/base_options.py @@ -0,0 +1,156 @@ +import argparse +import os +from ...pix2pix.util import util +# import torch +from ...pix2pix import models +# import pix2pix.data +import numpy as np + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class. + """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initialized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here') + # model parameters + parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') + parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') + parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN.
n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') + parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') + # dataset parameters + parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') + parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--load_size', type=int, default=672, help='scale images to this size') + parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size') + parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + parser.add_argument('--data_dir', type=str, required=False, + help='input files directory images can be .png .jpg .tiff') + parser.add_argument('--output_dir', type=str, required=False, + help='result dir. result depth will be png. 
videos are MJPG as avi') + parser.add_argument('--savecrops', type=int, required=False) + parser.add_argument('--savewholeest', type=int, required=False) + parser.add_argument('--output_resolution', type=int, required=False, + help='0 for no restriction 1 for resize to input size') + parser.add_argument('--net_receptive_field_size', type=int, required=False) + parser.add_argument('--pix2pixsize', type=int, required=False) + parser.add_argument('--generatevideo', type=int, required=False) + parser.add_argument('--depthNet', type=int, required=False, help='0: midas 1:structuredRL') + parser.add_argument('--R0', action='store_true') + parser.add_argument('--R20', action='store_true') + parser.add_argument('--Final', action='store_true') + parser.add_argument('--colorize_results', action='store_true') + parser.add_argument('--max_res', type=float, default=np.inf) + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options (only once). + Add additional model-specific and dataset-specific options. + These options are defined in the <modify_commandline_options> function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # modify dataset-related parser options + # dataset_name = opt.dataset_mode + # dataset_option_setter = pix2pix.data.get_option_setter(dataset_name) + # parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + #return parser.parse_args() #EVIL + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values (if different).
+ + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + #self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + #if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt diff --git a/modules/control/proc/leres/pix2pix/options/test_options.py b/modules/control/proc/leres/pix2pix/options/test_options.py new file mode 100644 index 000000000..a3424b5e3 --- /dev/null +++ b/modules/control/proc/leres/pix2pix/options/test_options.py @@ -0,0 +1,22 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + # Dropout and Batchnorm have different behavior during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') + # rewrite default values + parser.set_defaults(model='pix2pix4depth') + # To avoid cropping, the load_size should be the same as crop_size + parser.set_defaults(load_size=parser.get_default('crop_size')) + self.isTrain = False + return parser diff --git a/modules/control/proc/leres/pix2pix/util/__init__.py b/modules/control/proc/leres/pix2pix/util/__init__.py new file mode 100644 index 000000000..ae36f63d8 --- /dev/null +++ b/modules/control/proc/leres/pix2pix/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/modules/control/proc/leres/pix2pix/util/util.py b/modules/control/proc/leres/pix2pix/util/util.py new file mode 100644 index 000000000..8a7aceaa0 --- /dev/null +++ b/modules/control/proc/leres/pix2pix/util/util.py @@ -0,0 +1,105 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +def tensor2im(input_image, imtype=np.uint16): + """Converts a Tensor array into a numpy image array.
+ + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = torch.squeeze(image_tensor).cpu().numpy() # convert it into a numpy array + image_numpy = (image_numpy + 1) / 2.0 * (2**16-1) # + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + image_pil = Image.fromarray(image_numpy) + + image_pil = image_pil.convert('I;16') + + # image_pil = Image.fromarray(image_numpy) + # h, w, _ = image_numpy.shape + # + # if aspect_ratio > 1.0: + # image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + # if aspect_ratio < 1.0: + # image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) diff --git a/modules/control/proc/lineart.py b/modules/control/proc/lineart.py new file mode 100644 index 000000000..7f7aef10a --- /dev/null +++ b/modules/control/proc/lineart.py @@ -0,0 +1,166 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image + +norm_layer = nn.InstanceNorm2d + + +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + + conv_block = [ nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x): + return x + self.conv_block(x) + + +class Generator(nn.Module): + def __init__(self, 
input_nc, output_nc, n_residual_blocks=9, sigmoid=True): + super(Generator, self).__init__() + + # Initial convolution block + model0 = [ nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, 64, 7), + norm_layer(64), + nn.ReLU(inplace=True) ] + self.model0 = nn.Sequential(*model0) + + # Downsampling + model1 = [] + in_features = 64 + out_features = in_features*2 + for _ in range(2): + model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features*2 + self.model1 = nn.Sequential(*model1) + + model2 = [] + # Residual blocks + for _ in range(n_residual_blocks): + model2 += [ResidualBlock(in_features)] + self.model2 = nn.Sequential(*model2) + + # Upsampling + model3 = [] + out_features = in_features//2 + for _ in range(2): + model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features//2 + self.model3 = nn.Sequential(*model3) + + # Output layer + model4 = [ nn.ReflectionPad2d(3), + nn.Conv2d(64, output_nc, 7)] + if sigmoid: + model4 += [nn.Sigmoid()] + + self.model4 = nn.Sequential(*model4) + + def forward(self, x, cond=None): # pylint: disable=unused-argument + out = self.model0(x) + out = self.model1(out) + out = self.model2(out) + out = self.model3(out) + out = self.model4(out) + + return out + + +class LineartDetector: + def __init__(self, model, coarse_model): + self.model = model + self.model_coarse = coarse_model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, coarse_filename=None, cache_dir=None): + filename = filename or "sk_model.pth" + coarse_filename = coarse_filename or "sk_model2.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + coarse_model_path = os.path.join(pretrained_model_or_path, coarse_filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + coarse_model_path = hf_hub_download(pretrained_model_or_path, coarse_filename, cache_dir=cache_dir) + + model = Generator(3, 1, 3) + model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) + model.eval() + + coarse_model = Generator(3, 1, 3) + coarse_model.load_state_dict(torch.load(coarse_model_path, map_location=torch.device('cpu'))) + coarse_model.eval() + + return cls(model, coarse_model) + + def to(self, device): + self.model.to(device) + self.model_coarse.to(device) + return self + + def __call__(self, input_image, coarse=False, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + model = self.model_coarse if coarse else self.model + assert input_image.ndim == 3 + image = input_image + image = torch.from_numpy(image).float().to(device) + image = image / 255.0 + image = rearrange(image, 'h w c -> 1 c h w') + line = model(image)[0][0] + + line = line.cpu().numpy() + line = (line * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = line + + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + detected_map = 255 - detected_map + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/lineart_anime.py b/modules/control/proc/lineart_anime.py new file mode 100644 index 000000000..9eb4fcc09 --- /dev/null +++ b/modules/control/proc/lineart_anime.py @@ -0,0 +1,188 @@ +import functools +import os +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. 
+ """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): # pylint: disable=redefined-builtin + """Standard forward""" + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. 
+ """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +class LineartAnimeDetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None): + filename = filename or "netG.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False) + ckpt = torch.load(model_path) + for key in list(ckpt.keys()): + if 'module.' in key: + ckpt[key.replace('module.', '')] = ckpt[key] + del ckpt[key] + net.load_state_dict(ckpt) + net.eval() + + return cls(net) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + H, W, _C = input_image.shape + Hn = 256 * int(np.ceil(float(H) / 256.0)) + Wn = 256 * int(np.ceil(float(W) / 256.0)) + img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC) + image_feed = torch.from_numpy(img).float().to(device) + image_feed = image_feed / 127.5 - 1.0 + image_feed = rearrange(image_feed, 'h w c -> 1 c h w') + + line = self.model(image_feed)[0, 0] * 127.5 + 127.5 + line = line.cpu().numpy() + + line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC) + line = line.clip(0, 255).astype(np.uint8) + + detected_map = line + + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + detected_map = 255 - detected_map + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/mediapipe_face.py b/modules/control/proc/mediapipe_face.py new file mode 100644 index 000000000..187f14765 --- /dev/null +++ b/modules/control/proc/mediapipe_face.py @@ -0,0 +1,51 @@ +import warnings +from typing import Union +import cv2 +import numpy as np +from PIL import Image +from modules.control.util import HWC3, resize_image + + +class MediapipeFaceDetector: + def __call__(self, + input_image: Union[np.ndarray, Image.Image] = None, + max_faces: int = 1, + min_confidence: float = 0.5, + output_type: str = "pil", + detect_resolution: int = 512, + image_resolution: int = 512, + **kwargs): + + from .mediapipe_face_util import generate_annotation + if "image" in kwargs: + warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("image") + if input_image is None: + raise ValueError("input_image must be defined.") + + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + detected_map = generate_annotation(input_image, max_faces, min_confidence) + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/mediapipe_face_util.py b/modules/control/proc/mediapipe_face_util.py new file mode 100644 index 000000000..40ec0d204 --- /dev/null +++ b/modules/control/proc/mediapipe_face_util.py @@ -0,0 +1,162 @@ +from typing import Mapping +import numpy as np +from modules.shared import log + +try: + import mediapipe as mp +except ImportError: + log.error("Control processor MediaPipe: mediapipe not installed") + mp = None + +if mp: + mp_drawing = mp.solutions.drawing_utils + mp_drawing_styles = mp.solutions.drawing_styles + mp_face_detection = mp.solutions.face_detection # Only for counting faces. + mp_face_mesh = mp.solutions.face_mesh + mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION + mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS + mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS + + DrawingSpec = mp.solutions.drawing_styles.DrawingSpec + PoseLandmark = mp.solutions.drawing_styles.PoseLandmark + + min_face_size_pixels: int = 64 + f_thick = 2 + f_rad = 1 + right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) + right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) + right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) + left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) + left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) + left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) + mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad) + head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) + + # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. 
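+ # the loop below maps each face-mesh edge to the DrawingSpec of its feature (face oval, eyes, eyebrows, lips) so every feature is drawn in its own color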
+ face_connection_spec = {} + for edge in mp_face_mesh.FACEMESH_FACE_OVAL: + face_connection_spec[edge] = head_draw + for edge in mp_face_mesh.FACEMESH_LEFT_EYE: + face_connection_spec[edge] = left_eye_draw + for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: + face_connection_spec[edge] = left_eyebrow_draw + # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: + # face_connection_spec[edge] = left_iris_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: + face_connection_spec[edge] = right_eye_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: + face_connection_spec[edge] = right_eyebrow_draw + # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: + # face_connection_spec[edge] = right_iris_draw + for edge in mp_face_mesh.FACEMESH_LIPS: + face_connection_spec[edge] = mouth_draw + iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} + + +def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2): + """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all + landmarks. Until our PR is merged into mediapipe, we need this separate method.""" + if len(image.shape) != 3: + raise ValueError("Input image must be H,W,C.") + image_rows, image_cols, image_channels = image.shape + if image_channels != 3: # BGR channels + raise ValueError('Input image must contain three channel bgr data.') + for idx, landmark in enumerate(landmark_list.landmark): + if ( + (landmark.HasField('visibility') and landmark.visibility < 0.9) or + (landmark.HasField('presence') and landmark.presence < 0.5) + ): + continue + if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: + continue + image_x = int(image_cols*landmark.x) + image_y = int(image_rows*landmark.y) + draw_color = None + if isinstance(drawing_spec, Mapping): + if drawing_spec.get(idx) is None: + continue + else: + draw_color = drawing_spec[idx].color + elif isinstance(drawing_spec, DrawingSpec): + draw_color = drawing_spec.color + image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color + + +def reverse_channels(image): + """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB.""" + # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order. + # im[:,:,::[2,1,0]] would also work but makes a copy of the data. + return image[:, :, ::-1] + + +def generate_annotation( + img_rgb, + max_faces: int, + min_confidence: float +): + """ + Find up to 'max_faces' inside the provided input image. + If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many + pixels in the image. + """ + if mp is None: + return img_rgb + with mp_face_mesh.FaceMesh( + static_image_mode=True, + max_num_faces=max_faces, + refine_landmarks=True, + min_detection_confidence=min_confidence, + ) as facemesh: + img_height, img_width, img_channels = img_rgb.shape + assert img_channels == 3 + + results = facemesh.process(img_rgb).multi_face_landmarks + + if results is None: + print("No faces detected in controlnet image for Mediapipe face annotator.") + return np.zeros_like(img_rgb) + + # Filter faces that are too small + filtered_landmarks = [] + for lm in results: + landmarks = lm.landmark + face_rect = [ + landmarks[0].x, + landmarks[0].y, + landmarks[0].x, + landmarks[0].y, + ] # Left, up, right, down. 
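+ # expand the rectangle to cover every landmark; when min_face_size_pixels is set, faces whose shorter side is below that size are skipped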
+ for i in range(len(landmarks)): + face_rect[0] = min(face_rect[0], landmarks[i].x) + face_rect[1] = min(face_rect[1], landmarks[i].y) + face_rect[2] = max(face_rect[2], landmarks[i].x) + face_rect[3] = max(face_rect[3], landmarks[i].y) + if min_face_size_pixels > 0: + face_width = abs(face_rect[2] - face_rect[0]) + face_height = abs(face_rect[3] - face_rect[1]) + face_width_pixels = face_width * img_width + face_height_pixels = face_height * img_height + face_size = min(face_width_pixels, face_height_pixels) + if face_size >= min_face_size_pixels: + filtered_landmarks.append(lm) + else: + filtered_landmarks.append(lm) + + # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start. + empty = np.zeros_like(img_rgb) + + # Draw detected faces: + for face_landmarks in filtered_landmarks: + mp_drawing.draw_landmarks( + empty, + face_landmarks, + connections=face_connection_spec.keys(), + landmark_drawing_spec=None, + connection_drawing_spec=face_connection_spec + ) + draw_pupils(empty, face_landmarks, iris_landmark_spec, 2) + + # Flip BGR back to RGB. + empty = reverse_channels(empty).copy() + + return empty diff --git a/modules/control/proc/midas/LICENSE b/modules/control/proc/midas/LICENSE new file mode 100644 index 000000000..277b5c11b --- /dev/null +++ b/modules/control/proc/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/modules/control/proc/midas/__init__.py b/modules/control/proc/midas/__init__.py new file mode 100644 index 000000000..13a5ad061 --- /dev/null +++ b/modules/control/proc/midas/__init__.py @@ -0,0 +1,94 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, model_type="dpt_hybrid", filename=None, cache_dir=None): + if pretrained_model_or_path == "lllyasviel/ControlNet": + filename = filename or "annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + else: + filename = filename or "dpt_hybrid-midas-501f0c75.pt" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + + model = MiDaSInference(model_type=model_type, model_path=model_path) + + return cls(model) + + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, image_resolution=512, output_type=None): + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + image_depth = input_image + image_depth = torch.from_numpy(image_depth).float() + image_depth = image_depth.to(device) + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + if depth_and_normal: + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1] + + depth_image = HWC3(depth_image) + if depth_and_normal: + normal_image = HWC3(normal_image) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + depth_image = cv2.resize(depth_image, (W, H), interpolation=cv2.INTER_LINEAR) + if depth_and_normal: + normal_image = cv2.resize(normal_image, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + depth_image = Image.fromarray(depth_image) + if depth_and_normal: + normal_image = Image.fromarray(normal_image) + + if depth_and_normal: + return depth_image, normal_image + else: + return depth_image diff --git a/modules/control/proc/midas/api.py b/modules/control/proc/midas/api.py new file mode 100644 index 000000000..b1540cd9e --- /dev/null +++ b/modules/control/proc/midas/api.py @@ -0,0 +1,168 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from 
torchvision.transforms import Compose + +from .midas.dpt_depth import DPTDepthModel +from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from modules.control.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + +remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + raise AssertionError(f"model_type '{model_type}' not implemented, use: --model_type large") + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type, model_path=None): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = model_path or ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not 
implemented, use: --model_type large") + raise AssertionError + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type, model_path): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type, model_path) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + prediction = self.model(x) + return prediction + diff --git a/modules/control/proc/midas/midas/__init__.py b/modules/control/proc/midas/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/midas/midas/base_model.py b/modules/control/proc/midas/midas/base_model.py new file mode 100644 index 000000000..5cf430239 --- /dev/null +++ b/modules/control/proc/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/modules/control/proc/midas/midas/blocks.py b/modules/control/proc/midas/midas/blocks.py new file mode 100644 index 000000000..cb840ded3 --- /dev/null +++ b/modules/control/proc/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + raise AssertionError + + return pretrained, scratch + + +def _make_scratch(in_shape, 
out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand is True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand is True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/modules/control/proc/midas/midas/dpt_depth.py b/modules/control/proc/midas/midas/dpt_depth.py new file mode 100644 index 000000000..4429b7f94 --- /dev/null +++ b/modules/control/proc/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last is True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/modules/control/proc/midas/midas/midas_net.py b/modules/control/proc/midas/midas/midas_net.py new file mode 100644 index 000000000..8a9549778 --- /dev/null +++ b/modules/control/proc/midas/midas/midas_net.py @@ 
-0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/modules/control/proc/midas/midas/midas_net_custom.py b/modules/control/proc/midas/midas/midas_net_custom.py new file mode 100644 index 000000000..cba1bcfff --- /dev/null +++ b/modules/control/proc/midas/midas/midas_net_custom.py @@ -0,0 +1,130 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks=None): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. 
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + if blocks is None: + blocks = {"expand": True} + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] is True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last is True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name diff --git a/modules/control/proc/midas/midas/transforms.py b/modules/control/proc/midas/midas/transforms.py new file mode 100644 index 000000000..350cbc116 --- /dev/null +++ b/modules/control/proc/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if 
self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/modules/control/proc/midas/midas/vit.py b/modules/control/proc/midas/midas/vit.py new file mode 100644 index 000000000..f268a9fc4 --- /dev/null +++ b/modules/control/proc/midas/midas/vit.py @@ -0,0 +1,501 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // 
pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + raise AssertionError("wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'") + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=None, + size=None, + hooks=None, + vit_features=768, + use_readout="ignore", + start_index=1, +): + if hooks is None: + hooks = [2, 5, 8, 11] + if size is None: + size = [384, 384] + if features is None: + features = [96, 192, 384, 768] + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, 
use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=None, + size=None, + hooks=None, + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + if hooks is None: + hooks = [0, 1, 8, 11] + if size is None: + size = [384, 384] + if features is None: + features = [256, 512, 768, 768] + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only is True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only is True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: 
+ pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks is None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/modules/control/proc/midas/utils.py b/modules/control/proc/midas/utils.py new file mode 100644 index 000000000..9a9d3b5b6 --- /dev/null +++ b/modules/control/proc/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. 
+ """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/modules/control/proc/mlsd/LICENSE b/modules/control/proc/mlsd/LICENSE new file mode 100644 index 000000000..d855c6db4 --- /dev/null +++ b/modules/control/proc/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
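Note: a minimal usage sketch for the MidasDetector processor added above in modules/control/proc/midas/__init__.py. The checkpoint repo id mirrors the default hard-coded in from_pretrained; the input/output file names and the CUDA device selection are illustrative assumptions, not part of this diff.

import torch
from PIL import Image
from modules.control.proc.midas import MidasDetector

# load the DPT-Hybrid MiDaS checkpoint referenced by the processor defaults
midas = MidasDetector.from_pretrained("lllyasviel/ControlNet", model_type="dpt_hybrid")
midas = midas.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("input.png")  # illustrative input path
# returns a depth map only; pass depth_and_normal=True to also receive a pseudo normal map
depth = midas(image, detect_resolution=512, image_resolution=512, output_type="pil")
depth.save("depth.png")

The MLSDdetector added below follows the same from_pretrained / to / __call__ pattern.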
\ No newline at end of file diff --git a/modules/control/proc/mlsd/__init__.py b/modules/control/proc/mlsd/__init__.py new file mode 100644 index 000000000..456e1050d --- /dev/null +++ b/modules/control/proc/mlsd/__init__.py @@ -0,0 +1,78 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + + +class MLSDdetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None): + if pretrained_model_or_path == "lllyasviel/ControlNet": + filename = filename or "annotator/ckpts/mlsd_large_512_fp32.pth" + else: + filename = filename or "mlsd_large_512_fp32.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception: + pass + + detected_map = img_output[:, :, 0] + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/mlsd/models/__init__.py b/modules/control/proc/mlsd/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/mlsd/models/mbv2_mlsd_large.py b/modules/control/proc/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 000000000..39acf8dd5 --- /dev/null +++ b/modules/control/proc/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + 
nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
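+    # illustrative worked example: _make_divisible(37, 8) -> max(8, int(37 + 4) // 8 * 8) = 40;
+    # 40 >= 0.9 * 37, so 40 is returned, i.e. channel counts are snapped to a multiple of 8
+    # without ever shrinking them by more than 10%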
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + 
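+                # Kaiming (fan_out) initialization for conv kernels; conv biases (if any) are zeroed next,
+                # and BatchNorm/Linear layers are handled by the branches below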
if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x diff --git a/modules/control/proc/mlsd/models/mbv2_mlsd_tiny.py b/modules/control/proc/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 000000000..4f043851e --- /dev/null +++ b/modules/control/proc/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = 
self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + 
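+        # rows above follow the MobileNetV2 convention (t, c, n, s):
+        # expansion factor, output channels, number of repeated blocks, stride of the first block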
+ # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) + + return x diff --git a/modules/control/proc/mlsd/utils.py b/modules/control/proc/mlsd/utils.py new file mode 100644 index 000000000..ca8034370 --- /dev/null +++ b/modules/control/proc/mlsd/utils.py @@ -0,0 +1,580 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. 
+Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=None, + score_thr=0.10, + dist_thr=20.0): + if input_shape is None: + input_shape = [512, 512] + h, w, _ = image.shape + + device = next(iter(model.parameters())).device + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float() + batch_image = batch_image.to(device) + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=None, + params=None): + ''' + shape = [height, width] + ''' + if params is None: + params = {'score': 0.06, 'outside_ratio': 0.28, 'inside_ratio': 0.45, 'w_overlap': 0.0, 'w_degree': 1.95, 'w_length': 0.0, 'w_area': 1.86, 'w_center': 0.14} + if input_shape is None: + input_shape = [512, 512] + h, w, _ = image.shape + original_shape = [h, w] + device = next(iter(model.parameters())).device + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().to(device) + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] 
# (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt > 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + 
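+                            # same quantized (distance, angle) Hough bin as the current segment,
+                            # so merge that segment into this group as well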
segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... + ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in 
range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! + check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, _color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, _color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, _color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, _color_info = 3, 'cyan' + else: + corner_info, _color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... 
+ ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + square[start_idx] + square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + ###################################### + + overlap_scores = 
np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except Exception: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except Exception: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except Exception: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/modules/control/proc/normalbae/LICENSE b/modules/control/proc/normalbae/LICENSE new file mode 100644 index 000000000..16a9d56a3 --- /dev/null +++ b/modules/control/proc/normalbae/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Caroline Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/modules/control/proc/normalbae/__init__.py b/modules/control/proc/normalbae/__init__.py new file mode 100644 index 000000000..852189f68 --- /dev/null +++ b/modules/control/proc/normalbae/__init__.py @@ -0,0 +1,107 @@ +import os +import types +import warnings + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image +from .nets.NNET import NNET + + +# load model +def load_checkpoint(fpath, model): + ckpt = torch.load(fpath, map_location='cpu')['model'] + + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + model.load_state_dict(load_dict) + return model + +class NormalBaeDetector: + def __init__(self, model): + self.model = model + self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None): + filename = filename or "scannet.pt" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + + args = types.SimpleNamespace() + args.mode = 'client' + args.architecture = 'BN' + args.pretrained = 'scannet' + args.sampling_ratio = 0.4 + args.importance_ratio = 0.7 + model = NNET(args) + model = load_checkpoint(model_path, model) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + image_normal = input_image + image_normal = torch.from_numpy(image_normal).float().to(device) + image_normal = image_normal / 255.0 + image_normal = rearrange(image_normal, 'h w c -> 1 c h w') + image_normal = self.norm(image_normal) + + normal = self.model(image_normal) + normal = normal[0][-1][:, :3] + # d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5 + # d = torch.maximum(d, torch.ones_like(d) * 1e-5) + # normal /= d + normal = ((normal + 1) * 0.5).clip(0, 1) + + normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy() + normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = normal_image + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/normalbae/nets/NNET.py b/modules/control/proc/normalbae/nets/NNET.py new file mode 100644 index 000000000..1445fa5a0 --- /dev/null +++ b/modules/control/proc/normalbae/nets/NNET.py @@ -0,0 +1,19 @@ +import torch.nn as nn +from .submodules.encoder import Encoder +from .submodules.decoder import Decoder + + +class NNET(nn.Module): + def __init__(self, args): + super(NNET, self).__init__() + self.encoder = Encoder() + self.decoder = Decoder(args) + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def get_10x_lr_params(self): # lr learning rate + return self.decoder.parameters() + + def forward(self, img, **kwargs): + return self.decoder(self.encoder(img), **kwargs) diff --git a/modules/control/proc/normalbae/nets/__init__.py b/modules/control/proc/normalbae/nets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/normalbae/nets/baseline.py b/modules/control/proc/normalbae/nets/baseline.py new file mode 100644 index 000000000..61d610be3 --- /dev/null +++ b/modules/control/proc/normalbae/nets/baseline.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .submodules.submodules import UpSampleBN, norm_normalize + + +# This is the baseline encoder-decoder we used in the ablation study +class NNET(nn.Module): + def __init__(self, args=None): + super(NNET, self).__init__() + self.encoder = Encoder() + self.decoder = Decoder(num_classes=4) + + def forward(self, x, **kwargs): + out = self.decoder(self.encoder(x), **kwargs) + + # Bilinearly upsample the output to match the input resolution + up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False) + + # L2-normalize the first three channels / ensure positive value for concentration parameters (kappa) + up_out = norm_normalize(up_out) + return up_out + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def 
get_10x_lr_params(self): # lr learning rate + modules = [self.decoder] + for m in modules: + yield from m.parameters() + + +# Encoder +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + + basemodel_name = 'tf_efficientnet_b5_ap' + basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True) + + # Remove last layer + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for _ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + +# Decoder (no pixel-wise MLP, no uncertainty-guided sampling) +class Decoder(nn.Module): + def __init__(self, num_classes=4): + super(Decoder, self).__init__() + self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0) + self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128) + self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1) + + def forward(self, features): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + x_d0 = self.conv2(x_block4) + x_d1 = self.up1(x_d0, x_block3) + x_d2 = self.up2(x_d1, x_block2) + x_d3 = self.up3(x_d2, x_block1) + x_d4 = self.up4(x_d3, x_block0) + out = self.conv3(x_d4) + return out diff --git a/modules/control/proc/normalbae/nets/submodules/__init__.py b/modules/control/proc/normalbae/nets/submodules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/normalbae/nets/submodules/decoder.py b/modules/control/proc/normalbae/nets/submodules/decoder.py new file mode 100644 index 000000000..993203d17 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/decoder.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points + + +class Decoder(nn.Module): + def __init__(self, args): + super(Decoder, self).__init__() + + # hyper-parameter for sampling + self.sampling_ratio = args.sampling_ratio + self.importance_ratio = args.importance_ratio + + # feature-map + self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0) + if args.architecture == 'BN': + self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128) + + elif args.architecture == 'GN': + self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128) + + else: + raise Exception('invalid architecture') + + # produces 1/8 res output + self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + # produces 1/4 res output + self.out_conv_res4 = nn.Sequential( + nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + 
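+            # remaining pointwise (1x1 Conv1d) layers: 128 -> 128, then 128 -> 4
+            # (three normal components plus the kappa concentration channel)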
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + # produces 1/2 res output + self.out_conv_res2 = nn.Sequential( + nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + # produces 1/1 res output + self.out_conv_res1 = nn.Sequential( + nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + def forward(self, features, gt_norm_mask=None, mode='test'): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + + # generate feature-map + + x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res + x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res + x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res + x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res + x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res + + # 1/8 res output + out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output + out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output + + ################################################################################################################ + # out_res4 + ################################################################################################################ + + if mode == 'train': + # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160] + out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res8_res4.shape + + # samples: [B, 1, N, 2] + point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res4 = out_res8_res4 + + # grid_sample feature-map + feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N) + init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N) + samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized + + for i in range(B): + out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + # try all pixels + out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N) + out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized + out_res4 = out_res4.view(B, 4, H, W) + samples_pred_res4 = point_coords_res4 = None + + ################################################################################################################ + # out_res2 + ################################################################################################################ + + if mode == 'train': + + # 
upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320] + out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res4_res2.shape + + # samples: [B, 1, N, 2] + point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res2 = out_res4_res2 + + # grid_sample feature-map + feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N) + init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N) + samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized + + for i in range(B): + out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N) + out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized + out_res2 = out_res2.view(B, 4, H, W) + samples_pred_res2 = point_coords_res2 = None + + ################################################################################################################ + # out_res1 + ################################################################################################################ + + if mode == 'train': + # upsampling ... 
out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320] + out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res2_res1.shape + + # samples: [B, 1, N, 2] + point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res1 = out_res2_res1 + + # grid_sample feature-map + feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N) + init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N) + samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized + + for i in range(B): + out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N) + out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized + out_res1 = out_res1.view(B, 4, H, W) + samples_pred_res1 = point_coords_res1 = None + + return [out_res8, out_res4, out_res2, out_res1], \ + [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \ + [None, point_coords_res4, point_coords_res2, point_coords_res1] + diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md new file mode 100644 index 000000000..6ead7171c --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md @@ -0,0 +1,555 @@ +# Model Performance Benchmarks + +All benchmarks run as per: + +``` +python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx +python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx +python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3 +python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt +python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb +python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb +``` + +## EfficientNet-B0 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897 +Time per operator type: + 29.7378 ms. 60.5145%. Conv + 12.1785 ms. 24.7824%. Sigmoid + 3.62811 ms. 7.38297%. SpatialBN + 2.98444 ms. 6.07314%. Mul + 0.326902 ms. 0.665225%. AveragePool + 0.197317 ms. 0.401528%. FC + 0.0852877 ms. 0.173555%. Add + 0.0032607 ms. 0.00663532%. Squeeze + 49.1416 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 95.2696%. Conv + 0.0269508 GFLOP. 3.33857%. SpatialBN + 0.00846444 GFLOP. 1.04855%. Mul + 0.002561 GFLOP. 0.317248%. FC + 0.000210112 GFLOP. 0.0260279%. Add + 0.807256 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 43.0891%. Mul + 43.2015 MB. 31.807%. Conv + 27.2869 MB. 20.0899%. SpatialBN + 5.12912 MB. 3.77631%. 
FC + 1.6809 MB. 1.23756%. Add + 135.824 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 38.1965%. Mul + 26.9881 MB. 30.4465%. Conv + 26.9508 MB. 30.4044%. SpatialBN + 0.840448 MB. 0.948147%. Add + 0.004 MB. 0.00451258%. FC + 88.6412 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 74.9391%. Conv + 5.124 MB. 24.265%. FC + 0.168064 MB. 0.795877%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 21.1168 MB in Total +``` +### Optimized +``` +Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996 +Time per operator type: + 29.776 ms. 65.002%. Conv + 12.2803 ms. 26.8084%. Sigmoid + 3.15073 ms. 6.87815%. Mul + 0.328651 ms. 0.717456%. AveragePool + 0.186237 ms. 0.406563%. FC + 0.0832429 ms. 0.181722%. Add + 0.0026184 ms. 0.00571606%. Squeeze + 45.8078 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 98.5601%. Conv + 0.00846444 GFLOP. 1.08476%. Mul + 0.002561 GFLOP. 0.328205%. FC + 0.000210112 GFLOP. 0.0269269%. Add + 0.780305 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 53.8803%. Mul + 43.2855 MB. 39.8501%. Conv + 5.12912 MB. 4.72204%. FC + 1.6809 MB. 1.54749%. Add + 108.621 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 54.8834%. Mul + 26.9881 MB. 43.7477%. Conv + 0.840448 MB. 1.36237%. Add + 0.004 MB. 0.00648399%. FC + 61.6904 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 75.5403%. Conv + 5.124 MB. 24.4597%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 20.9488 MB in Total +``` + +## EfficientNet-B1 +### Optimized +``` +Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256 +Time per operator type: + 45.7915 ms. 66.3206%. Conv + 17.8718 ms. 25.8841%. Sigmoid + 4.44132 ms. 6.43244%. Mul + 0.51001 ms. 0.738658%. AveragePool + 0.233283 ms. 0.337868%. Add + 0.194986 ms. 0.282402%. FC + 0.00268255 ms. 0.00388519%. Squeeze + 69.0456 ms in Total +FLOP per operator type: + 1.37105 GFLOP. 98.7673%. Conv + 0.0138759 GFLOP. 0.99959%. Mul + 0.002561 GFLOP. 0.184489%. FC + 0.000674432 GFLOP. 0.0485847%. Add + 1.38816 GFLOP in Total +Feature Memory Read per operator type: + 94.624 MB. 54.0789%. Mul + 69.8255 MB. 39.9062%. Conv + 5.39546 MB. 3.08357%. Add + 5.12912 MB. 2.93136%. FC + 174.974 MB in Total +Feature Memory Written per operator type: + 55.5035 MB. 54.555%. Mul + 43.5333 MB. 42.7894%. Conv + 2.69773 MB. 2.65163%. Add + 0.004 MB. 0.00393165%. FC + 101.739 MB in Total +Parameter Memory per operator type: + 25.7479 MB. 83.4024%. Conv + 5.124 MB. 16.5976%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 30.8719 MB in Total +``` + +## EfficientNet-B2 +### Optimized +``` +Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366 +Time per operator type: + 61.4627 ms. 67.5845%. Conv + 22.7458 ms. 25.0113%. Sigmoid + 5.59931 ms. 6.15701%. Mul + 0.642567 ms. 0.706568%. AveragePool + 0.272795 ms. 0.299965%. Add + 0.216178 ms. 0.237709%. FC + 0.00268895 ms. 0.00295677%. Squeeze + 90.942 ms in Total +FLOP per operator type: + 1.98431 GFLOP. 98.9343%. Conv + 0.0177039 GFLOP. 0.882686%. Mul + 0.002817 GFLOP. 0.140451%. FC + 0.000853984 GFLOP. 0.0425782%. Add + 2.00568 GFLOP in Total +Feature Memory Read per operator type: + 120.609 MB. 54.9637%. Mul + 86.3512 MB. 39.3519%. Conv + 6.83187 MB. 3.11341%. Add + 5.64163 MB. 2.571%. FC + 219.433 MB in Total +Feature Memory Written per operator type: + 70.8155 MB. 54.6573%. Mul + 55.3273 MB. 42.7031%. Conv + 3.41594 MB. 2.63651%. Add + 0.004 MB. 0.00308731%. FC + 129.563 MB in Total +Parameter Memory per operator type: + 30.4721 MB. 
84.3913%. Conv + 5.636 MB. 15.6087%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 36.1081 MB in Total +``` + +## MixNet-M +### Optimized +``` +Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448 +Time per operator type: + 48.1139 ms. 75.2052%. Conv + 7.1341 ms. 11.1511%. Sigmoid + 2.63706 ms. 4.12189%. SpatialBN + 1.73186 ms. 2.70701%. Mul + 1.38707 ms. 2.16809%. Split + 1.29322 ms. 2.02139%. Concat + 1.00093 ms. 1.56452%. Relu + 0.235309 ms. 0.367803%. Add + 0.221579 ms. 0.346343%. FC + 0.219315 ms. 0.342803%. AveragePool + 0.00250145 ms. 0.00390993%. Squeeze + 63.9768 ms in Total +FLOP per operator type: + 0.675273 GFLOP. 95.5827%. Conv + 0.0221072 GFLOP. 3.12921%. SpatialBN + 0.00538445 GFLOP. 0.762152%. Mul + 0.003073 GFLOP. 0.434973%. FC + 0.000642488 GFLOP. 0.0909421%. Add + 0 GFLOP. 0%. Concat + 0 GFLOP. 0%. Relu + 0.70648 GFLOP in Total +Feature Memory Read per operator type: + 46.8424 MB. 30.502%. Conv + 36.8626 MB. 24.0036%. Mul + 22.3152 MB. 14.5309%. SpatialBN + 22.1074 MB. 14.3955%. Concat + 14.1496 MB. 9.21372%. Relu + 6.15414 MB. 4.00735%. FC + 5.1399 MB. 3.34692%. Add + 153.571 MB in Total +Feature Memory Written per operator type: + 32.7672 MB. 28.4331%. Conv + 22.1072 MB. 19.1831%. Concat + 22.1072 MB. 19.1831%. SpatialBN + 21.5378 MB. 18.689%. Mul + 14.1496 MB. 12.2781%. Relu + 2.56995 MB. 2.23003%. Add + 0.004 MB. 0.00347092%. FC + 115.243 MB in Total +Parameter Memory per operator type: + 13.7059 MB. 68.674%. Conv + 6.148 MB. 30.8049%. FC + 0.104 MB. 0.521097%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Concat + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 19.9579 MB in Total +``` + +## TF MobileNet-V3 Large 1.0 + +### Optimized +``` +Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525 +Time per operator type: + 17.437 ms. 80.0087%. Conv + 1.27662 ms. 5.8577%. Add + 1.12759 ms. 5.17387%. Div + 0.701155 ms. 3.21721%. Mul + 0.562654 ms. 2.58171%. Relu + 0.431144 ms. 1.97828%. Clip + 0.156902 ms. 0.719936%. FC + 0.0996858 ms. 0.457402%. AveragePool + 0.00112455 ms. 0.00515993%. Flatten + 21.7939 ms in Total +FLOP per operator type: + 0.43062 GFLOP. 98.1484%. Conv + 0.002561 GFLOP. 0.583713%. FC + 0.00210867 GFLOP. 0.480616%. Mul + 0.00193868 GFLOP. 0.441871%. Add + 0.00151532 GFLOP. 0.345377%. Div + 0 GFLOP. 0%. Relu + 0.438743 GFLOP in Total +Feature Memory Read per operator type: + 34.7967 MB. 43.9391%. Conv + 14.496 MB. 18.3046%. Mul + 9.44828 MB. 11.9307%. Add + 9.26157 MB. 11.6949%. Relu + 6.0614 MB. 7.65395%. Div + 5.12912 MB. 6.47673%. FC + 79.193 MB in Total +Feature Memory Written per operator type: + 17.6247 MB. 35.8656%. Conv + 9.26157 MB. 18.847%. Relu + 8.43469 MB. 17.1643%. Mul + 7.75472 MB. 15.7806%. Add + 6.06128 MB. 12.3345%. Div + 0.004 MB. 0.00813985%. FC + 49.1409 MB in Total +Parameter Memory per operator type: + 16.6851 MB. 76.5052%. Conv + 5.124 MB. 23.4948%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8091 MB in Total +``` + +## MobileNet-V3 (RW) + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712 +Time per operator type: + 15.9266 ms. 69.2624%. Conv + 2.36551 ms. 10.2873%. SpatialBN + 1.39102 ms. 6.04936%. Add + 1.30327 ms. 5.66773%. Div + 0.737014 ms. 3.20517%. Mul + 0.639697 ms. 2.78195%. Relu + 0.375681 ms. 1.63378%. Clip + 0.153126 ms. 0.665921%. FC + 0.0993787 ms. 0.432184%. AveragePool + 0.0032632 ms. 0.0141912%. Squeeze + 22.9946 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 94.4041%. Conv + 0.0175992 GFLOP. 3.85829%. 
SpatialBN + 0.002561 GFLOP. 0.561449%. FC + 0.00210961 GFLOP. 0.46249%. Mul + 0.00173891 GFLOP. 0.381223%. Add + 0.00151626 GFLOP. 0.33241%. Div + 0 GFLOP. 0%. Relu + 0.456141 GFLOP in Total +Feature Memory Read per operator type: + 34.7354 MB. 36.4363%. Conv + 17.7944 MB. 18.6658%. SpatialBN + 14.5035 MB. 15.2137%. Mul + 9.25778 MB. 9.71113%. Relu + 7.84641 MB. 8.23064%. Add + 6.06516 MB. 6.36216%. Div + 5.12912 MB. 5.38029%. FC + 95.3317 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 26.7264%. Conv + 17.5992 MB. 26.6878%. SpatialBN + 9.25778 MB. 14.0387%. Relu + 8.43843 MB. 12.7962%. Mul + 6.95565 MB. 10.5477%. Add + 6.06502 MB. 9.19713%. Div + 0.004 MB. 0.00606568%. FC + 65.9447 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.1564%. Conv + 5.124 MB. 23.3979%. FC + 0.0976 MB. 0.445674%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8994 MB in Total + +``` +### Optimized + +``` +Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527 +Time per operator type: + 17.146 ms. 78.8965%. Conv + 1.38453 ms. 6.37084%. Add + 1.30991 ms. 6.02749%. Div + 0.685417 ms. 3.15391%. Mul + 0.532589 ms. 2.45068%. Relu + 0.418263 ms. 1.92461%. Clip + 0.15128 ms. 0.696106%. FC + 0.102065 ms. 0.469648%. AveragePool + 0.0022143 ms. 0.010189%. Squeeze + 21.7323 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 98.1927%. Conv + 0.002561 GFLOP. 0.583981%. FC + 0.00210961 GFLOP. 0.481051%. Mul + 0.00173891 GFLOP. 0.396522%. Add + 0.00151626 GFLOP. 0.34575%. Div + 0 GFLOP. 0%. Relu + 0.438542 GFLOP in Total +Feature Memory Read per operator type: + 34.7842 MB. 44.833%. Conv + 14.5035 MB. 18.6934%. Mul + 9.25778 MB. 11.9323%. Relu + 7.84641 MB. 10.1132%. Add + 6.06516 MB. 7.81733%. Div + 5.12912 MB. 6.61087%. FC + 77.5861 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 36.4556%. Conv + 9.25778 MB. 19.1492%. Relu + 8.43843 MB. 17.4544%. Mul + 6.95565 MB. 14.3874%. Add + 6.06502 MB. 12.5452%. Div + 0.004 MB. 0.00827378%. FC + 48.3455 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.4973%. Conv + 5.124 MB. 23.5027%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8018 MB in Total + +``` + +## MnasNet-A1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345 +Time per operator type: + 24.4656 ms. 79.0905%. Conv + 4.14958 ms. 13.4144%. SpatialBN + 1.60598 ms. 5.19169%. Relu + 0.295219 ms. 0.95436%. Mul + 0.187609 ms. 0.606486%. FC + 0.120556 ms. 0.389724%. AveragePool + 0.09036 ms. 0.292109%. Add + 0.015727 ms. 0.050841%. Sigmoid + 0.00306205 ms. 0.00989875%. Squeeze + 30.9337 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 95.6434%. Conv + 0.0248873 GFLOP. 3.8355%. SpatialBN + 0.002561 GFLOP. 0.394688%. FC + 0.000597408 GFLOP. 0.0920695%. Mul + 0.000222656 GFLOP. 0.0343146%. Add + 0 GFLOP. 0%. Relu + 0.648867 GFLOP in Total +Feature Memory Read per operator type: + 35.5457 MB. 38.4109%. Conv + 25.1552 MB. 27.1829%. SpatialBN + 22.5235 MB. 24.339%. Relu + 5.12912 MB. 5.54256%. FC + 2.40586 MB. 2.59978%. Mul + 1.78125 MB. 1.92483%. Add + 92.5406 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 32.9424%. Conv + 24.8873 MB. 32.92%. SpatialBN + 22.5235 MB. 29.7932%. Relu + 2.38963 MB. 3.16092%. Mul + 0.890624 MB. 1.17809%. Add + 0.004 MB. 0.00529106%. FC + 75.5993 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.1459%. Conv + 5.124 MB. 32.9917%. FC + 0.133952 MB. 0.86247%. 
SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.5312 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597 +Time per operator type: + 22.0547 ms. 91.1375%. Conv + 1.49096 ms. 6.16116%. Relu + 0.253417 ms. 1.0472%. Mul + 0.18506 ms. 0.76473%. FC + 0.112942 ms. 0.466717%. AveragePool + 0.086769 ms. 0.358559%. Add + 0.0127889 ms. 0.0528479%. Sigmoid + 0.0027346 ms. 0.0113003%. Squeeze + 24.1994 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 99.4581%. Conv + 0.002561 GFLOP. 0.41043%. FC + 0.000597408 GFLOP. 0.0957417%. Mul + 0.000222656 GFLOP. 0.0356832%. Add + 0 GFLOP. 0%. Relu + 0.623979 GFLOP in Total +Feature Memory Read per operator type: + 35.6127 MB. 52.7968%. Conv + 22.5235 MB. 33.3917%. Relu + 5.12912 MB. 7.60406%. FC + 2.40586 MB. 3.56675%. Mul + 1.78125 MB. 2.64075%. Add + 67.4524 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 49.1092%. Conv + 22.5235 MB. 44.4145%. Relu + 2.38963 MB. 4.71216%. Mul + 0.890624 MB. 1.75624%. Add + 0.004 MB. 0.00788768%. FC + 50.712 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.7213%. Conv + 5.124 MB. 33.2787%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.3972 MB in Total +``` +## MnasNet-B1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322 +Time per operator type: + 29.1121 ms. 83.3081%. Conv + 4.14959 ms. 11.8746%. SpatialBN + 1.35823 ms. 3.88675%. Relu + 0.186188 ms. 0.532802%. FC + 0.116244 ms. 0.332647%. Add + 0.018641 ms. 0.0533437%. AveragePool + 0.0040904 ms. 0.0117052%. Squeeze + 34.9451 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 96.2088%. Conv + 0.0218266 GFLOP. 3.35303%. SpatialBN + 0.002561 GFLOP. 0.393424%. FC + 0.000291648 GFLOP. 0.0448034%. Add + 0 GFLOP. 0%. Relu + 0.650951 GFLOP in Total +Feature Memory Read per operator type: + 34.4354 MB. 41.3788%. Conv + 22.1299 MB. 26.5921%. SpatialBN + 19.1923 MB. 23.0622%. Relu + 5.12912 MB. 6.16333%. FC + 2.33318 MB. 2.80364%. Add + 83.2199 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 34.0955%. Conv + 21.8266 MB. 34.0955%. SpatialBN + 19.1923 MB. 29.9805%. Relu + 1.16659 MB. 1.82234%. Add + 0.004 MB. 0.00624844%. FC + 64.016 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 69.9104%. Conv + 5.124 MB. 29.2245%. FC + 0.15168 MB. 0.865099%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Relu + 17.5332 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426 +Time per operator type: + 24.9888 ms. 94.0962%. Conv + 1.26147 ms. 4.75011%. Relu + 0.176234 ms. 0.663619%. FC + 0.113309 ms. 0.426672%. Add + 0.0138708 ms. 0.0522311%. AveragePool + 0.00295685 ms. 0.0111341%. Squeeze + 26.5566 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 99.5466%. Conv + 0.002561 GFLOP. 0.407074%. FC + 0.000291648 GFLOP. 0.0463578%. Add + 0 GFLOP. 0%. Relu + 0.629124 GFLOP in Total +Feature Memory Read per operator type: + 34.5112 MB. 56.4224%. Conv + 19.1923 MB. 31.3775%. Relu + 5.12912 MB. 8.3856%. FC + 2.33318 MB. 3.81452%. Add + 61.1658 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 51.7346%. Conv + 19.1923 MB. 45.4908%. Relu + 1.16659 MB. 2.76513%. Add + 0.004 MB. 0.00948104%. FC + 42.1895 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 70.5205%. Conv + 5.124 MB. 29.4795%. FC + 0 MB. 0%. Add + 0 MB. 0%. 
Relu + 17.3816 MB in Total +``` diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/LICENSE b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/LICENSE new file mode 100644 index 000000000..80e7d1550 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 Ross Wightman + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/README.md b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/README.md new file mode 100644 index 000000000..463368280 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/README.md @@ -0,0 +1,323 @@ +# (Generic) EfficientNets for PyTorch + +A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search. + +All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py)) + +## What's New + +### Aug 19, 2020 +* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1) +* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1) +* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX +* ONNX runtime based validation script added +* activations (mostly) brought in sync with `timm` equivalents + + +### April 5, 2020 +* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite + * 3.5M param MobileNet-V2 100 @ 73% + * 4.5M param MobileNet-V2 110d @ 75% + * 6.1M param MobileNet-V2 140 @ 76.5% + * 5.8M param MobileNet-V2 120d @ 77.3% + +### March 23, 2020 + * Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1 + * IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior + +### Feb 12, 2020 + * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) + * Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization. + * Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) + +### Jan 22, 2020 + * Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. 
Trained with (https://github.com/rwightman/pytorch-image-models) + * Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict + * Test models, torchscript, onnx export with PyTorch 1.4 -- no issues + +### Nov 22, 2019 + * New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different + preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights. + +### Nov 15, 2019 + * Ported official TF MobileNet-V3 float32 large/small/minimalistic weights + * Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine + +### Oct 30, 2019 + * Many of the models will now work with torch.jit.script, MixNet being the biggest exception + * Improved interface for enabling torchscript or ONNX export compatible modes (via config) + * Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn + * Activation factory to select best version of activation by name or override one globally + * Add pretrained checkpoint load helper that handles input conv and classifier changes + +### Oct 27, 2019 + * Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + * Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + * Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base + * Switch activations and global pooling to modules + * Add memory-efficient Swish/Mish impl + * Add as_sequential() method to all models and allow as an argument in entrypoint fns + * Move MobileNetV3 into own file since it has a different head + * Remove ChamNet, MobileNet V2/V1 since they will likely never be used here + +## Models + +Implemented models include: + * EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252) + * EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) + * EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946) + * EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) + * EfficientNet-CondConv (https://arxiv.org/abs/1904.04971) + * EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * MixNet (https://arxiv.org/abs/1907.09595) + * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) + * MobileNet-V3 (https://arxiv.org/abs/1905.02244) + * FBNet-C (https://arxiv.org/abs/1812.03443) + * Single-Path NAS (https://arxiv.org/abs/1904.02877) + +I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code. + +## Pretrained + +I've managed to train several of the models to accuracies close to or above the originating papers and official impl. 
My training code is here: https://github.com/rwightman/pytorch-image-models + + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop | +|---|---|---|---|---|---|---|---| +| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 | +| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 | +| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 | +| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 | +| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 | +| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 | +| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 | +| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 | +| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 | +| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 | +| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 | +| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 | +| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 | +| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 | +| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 | +| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 | +| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 | +| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 | + + +More pretrained models to come... + + +## Ported Weights + +The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args. + +**IMPORTANT:** +* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std. +* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl. 
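+
+The Inception-style normalization noted above can also be reproduced outside of `validate.py` with plain torchvision transforms. The snippet below is a minimal sketch, not part of this repo (it assumes torchvision/Pillow are installed); the crop follows the 0.875 crop-pct used by most 224x224 models in the tables, and the usual ImageNet statistics are shown for comparison with the original EfficientNet base/AA/RA weights.
+
+```python
+from torchvision import transforms
+
+# TF-ported AP / EdgeTPU / CondConv / Lite / MobileNet-V3 weights: Inception-style mean/std
+tf_port_preproc = transforms.Compose([
+    transforms.Resize(256),            # 224 / 0.875 crop-pct; see tables for bicubic vs bilinear scaling
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+])
+
+# Original EfficientNet base/AA/RA weights: standard ImageNet mean/std
+imagenet_preproc = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+])
+```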
+ +To run validation for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic` + +To run validation w/ TF preprocessing for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing` + +To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5` + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop | +|---|---|---|---|---|---|---| +| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A | +| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 | +| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 | +| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A | +| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 | +| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 | +| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 | +| 
tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 | +| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 | +| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A | +| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 | +| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A | +| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 | +| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A | +| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A | +| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A | +| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A | +| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 | +| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 | +| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 | +| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 
| 0.875 | +| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A | +| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 | +| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 | +| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A | +| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 | +| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A | +| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 | +| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A | +| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 | +| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A | +| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A | +| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 | + + +*tfp models validated with `tf-preprocessing` pipeline + +Google tf and tflite weights ported from official Tensorflow repositories +* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet +* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet +* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet + +## Usage + +### Environment + +All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x. + +Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself. + +PyTorch versions 1.4, 1.5, 1.6 have been tested with this code. 
+ +I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: +``` +conda create -n torch-env +conda activate torch-env +conda install -c pytorch pytorch torchvision cudatoolkit=10.2 +``` + +### PyTorch Hub + +Models can be accessed via the PyTorch Hub API + +``` +>>> torch.hub.list('rwightman/gen-efficientnet-pytorch') +['efficientnet_b0', ...] +>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True) +>>> model.eval() +>>> output = model(torch.randn(1,3,224,224)) +``` + +### Pip +This package can be installed via pip. + +Install (after conda env/install): +``` +pip install geffnet +``` + +Eval use: +``` +>>> import geffnet +>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True) +>>> m.eval() +``` + +Train use: +``` +>>> import geffnet +>>> # models can also be created by using the entrypoint directly +>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2) +>>> m.train() +``` + +Create in a nn.Sequential container, for fast.ai, etc: +``` +>>> import geffnet +>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True) +``` + +### Exporting + +Scripts are included to +* export models to ONNX (`onnx_export.py`) +* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg) +* validate with ONNX runtime (`onnx_validate.py`) +* convert ONNX model to Caffe2 (`onnx_to_caffe.py`) +* validate in Caffe2 (`caffe2_validate.py`) +* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`) + +As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation: +``` +python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx +python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx +``` + +These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible +export now requires additional args mentioned in the export script (not needed in earlier versions). + +#### Export Notes +1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script. +2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working. +3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization. +3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here. 
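+
+For programmatic export (rather than the `onnx_export.py` script), the export-friendly mode from note 1 is enabled before the model is created, so the 'SAME'-padded convolutions are built in their static, traceable form. The block below is only a minimal sketch of that flow; the model name, output file, tensor names and opset are illustrative, not taken from `onnx_export.py`:
+
+```python
+import torch
+import geffnet
+
+geffnet.config.set_exportable(True)  # build export-safe static 'SAME' padding (see note 1)
+
+model = geffnet.create_model('tf_efficientnet_b0', pretrained=True)
+model.eval()
+
+# Padding is fixed at the export resolution (see note 2), so trace at the size you will run at.
+dummy = torch.randn(1, 3, 224, 224)
+torch.onnx.export(
+    model, dummy, 'tf_efficientnet_b0.onnx',
+    input_names=['input'], output_names=['logits'], opset_version=10)
+```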
+ + diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/__init__.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py new file mode 100644 index 000000000..ca60ac711 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py @@ -0,0 +1,5 @@ +from .gen_efficientnet import * +from .mobilenetv3 import * +from .model_factory import create_model +from .config import is_exportable, is_scriptable, set_exportable, set_scriptable +from .activations import * diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py new file mode 100644 index 000000000..813421a74 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py @@ -0,0 +1,137 @@ +from geffnet import config +from geffnet.activations.activations_me import * +from geffnet.activations.activations_jit import * +from geffnet.activations.activations import * +import torch + +_has_silu = 'silu' in dir(torch.nn.functional) + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=mish, + relu=F.relu, + relu6=F.relu6, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=hard_sigmoid, + hard_swish=hard_swish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=mish_jit, +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=mish_me, + hard_swish=hard_swish_me, + hard_sigmoid_jit=hard_sigmoid_me, +) + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=HardSigmoid, + hard_swish=HardSwish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=MishJit, +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=MishMe, + hard_swish=HardSwishMe, + hard_sigmoid=HardSigmoidMe +) + +_OVERRIDE_FN = dict() +_OVERRIDE_LAYER = dict() + + +def add_override_act_fn(name, fn): + global _OVERRIDE_FN + _OVERRIDE_FN[name] = fn + + +def update_override_act_fn(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_FN + _OVERRIDE_FN.update(overrides) + + +def clear_override_act_fn(): + global _OVERRIDE_FN + _OVERRIDE_FN = dict() + + +def add_override_act_layer(name, fn): + _OVERRIDE_LAYER[name] = fn + + +def update_override_act_layer(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_LAYER + _OVERRIDE_LAYER.update(overrides) + + +def clear_override_act_layer(): + global _OVERRIDE_LAYER + _OVERRIDE_LAYER = dict() + + +def get_act_fn(name='relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. 
+ """ + if name in _OVERRIDE_FN: + return _OVERRIDE_FN[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_FN_ME: + # If not exporting or scripting the model, first look for a memory optimized version + # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin + return _ACT_FN_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name='relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_LAYER: + return _OVERRIDE_LAYER[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py new file mode 100644 index 000000000..bdea692d1 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py @@ -0,0 +1,102 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). 
+ + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return mish(x, self.inplace) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py new file mode 100644 index 000000000..7176b05e7 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py @@ -0,0 +1,79 @@ +""" Activations (jit) + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + +__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', + 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). 
+ + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py new file mode 100644 index 000000000..e91df5a50 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py @@ -0,0 +1,174 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', + 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + + Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). 
+ + TODO Rename to SiLU with addition to PyTorch + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. 
+ .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py new file mode 100644 index 000000000..27d5307fd --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py @@ -0,0 +1,123 @@ +""" Global layer config state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. 
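Usage sketch for the flag helpers above (assumes the vendored geffnet package is importable; adjust the import to wherever this repo is mounted). The classes double as context managers, restoring the previous flag value on exit:

from geffnet.config import set_exportable, is_exportable, is_scriptable

print(is_exportable(), is_scriptable())  # False False by default
with set_exportable(True):
    # while the flag is set, factories such as create_conv2d_pad (defined in
    # conv2d_layers.py below) return the ONNX-friendly Conv2dSameExport for
    # dynamic 'SAME' padding
    print(is_exportable())               # True
print(is_exportable())                   # restored to False on exit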
+ """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False + + +def layer_config_kwargs(kwargs): + """ Consume config kwargs and return contextmgr obj """ + return set_layer_config( + scriptable=kwargs.pop('scriptable', None), + exportable=kwargs.pop('exportable', None), + no_jit=kwargs.pop('no_jit', None)) diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py new file mode 100644 index 000000000..d8467460c --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py @@ -0,0 +1,304 @@ +""" Conv2D w/ SAME padding, CondConv, MixedConv + +A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and +MobileNetV3 models that maintain weight compatibility with original Tensorflow models. + +Copyright 2020 Ross Wightman +""" +import collections.abc +import math +from functools import partial +from itertools import repeat +from typing import Tuple, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import * + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i: int, k: int, s: int, d: int): + return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _same_pad_arg(input_size, kernel_size, stride, dilation): + ih, iw = input_size + kh, kw = kernel_size + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, 
bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dSameExport(nn.Conv2d): + """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions + + NOTE: This does not currently work with torch.jit.script + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSameExport, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.pad = None + self.pad_input_size = (0, 0) + + def forward(self, x): + input_size = x.size()[-2:] + if self.pad is None: + pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) + self.pad = nn.ZeroPad2d(pad_arg) + self.pad_input_size = input_size + + if self.pad is not None: + x = self.pad(x) + return F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + else: + # dynamic padding + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + if is_exportable(): + assert not is_scriptable() + return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) + else: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + self.add_module( + str(idx), + 
create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = _pair(padding_val) + self.dilation = _pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, 
padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out + + +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py new file mode 100644 index 000000000..0343e3f44 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py @@ -0,0 +1,683 @@ +""" EfficientNet / MobileNetV3 Blocks and Builder + +Copyright 2020 Ross Wightman +""" +import re +from copy import deepcopy + +from .conv2d_layers import * +from geffnet.activations import * + +__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', + 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', + 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', + 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' +] + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. 
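Standalone sketch (not part of the patch) of the grouped-convolution trick used in CondConv2d.forward above: folding the batch into the channel dimension with groups=B applies a different expert-mixed kernel to each sample in a single conv2d call, matching a per-sample loop:

import torch
import torch.nn.functional as F

B, C_in, C_out, H, W, k = 4, 8, 16, 10, 10, 3
x = torch.randn(B, C_in, H, W)
w = torch.randn(B, C_out, C_in, k, k)   # one (already routed) kernel per sample

# reference: convolve each sample with its own kernel
ref = torch.cat([F.conv2d(x[i:i + 1], w[i], padding=1) for i in range(B)], dim=0)

# grouped trick: one call with the batch folded into channels
out = F.conv2d(x.reshape(1, B * C_in, H, W), w.reshape(B * C_out, C_in, k, k),
               padding=1, groups=B)
print(torch.allclose(ref, out.reshape(B, C_out, H, W), atol=1e-5))  # True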
PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +# +# PyTorch defaults are momentum = .1, eps = 1e-5 +# +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +_SE_ARGS_DEFAULT = dict( + gate_fn=sigmoid, + act_layer=None, # None == use containing block's activation layer + reduce_mid=False, + divisor=1) + + +def resolve_se_args(kwargs, in_chs, act_layer=None): + se_kwargs = kwargs.copy() if kwargs is not None else {} + # fill in args that aren't specified with the defaults + for k, v in _SE_ARGS_DEFAULT.items(): + se_kwargs.setdefault(k, v) + # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch + if not se_kwargs.pop('reduce_mid'): + se_kwargs['reduced_base_chs'] = in_chs + # act_layer override, if it remains None, the containing block's act_layer will be used + if se_kwargs['act_layer'] is None: + assert act_layer is not None + se_kwargs['act_layer'] = act_layer + return se_kwargs + + +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + +def make_divisible(v: int, divisor: int = 8, min_value: int = None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: # ensure round down does not go down by more than 10%. 
+ new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): + """Apply drop connect.""" + if not training: + return inputs + + keep_prob = 1 - drop_connect_rate + random_tensor = keep_prob + torch.rand( + (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) + random_tensor.floor_() # binarize + output = inputs.div(keep_prob) * random_tensor + return output + + +class SqueezeExcite(nn.Module): + + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1): + super(SqueezeExcite, self).__init__() + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion + factor of 1.0. This is an alternative to having a IR with optional first pw conv. 
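Quick numeric check (standalone sketch) of the channel rounding above: widths are snapped to the nearest multiple of the divisor, with a bump upward if rounding would lose more than 10% of the requested width:

def make_divisible(v, divisor=8, min_value=None):   # mirrors the helper above
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(make_divisible(24 * 1.1))    # 26.4 -> 24 (nearest multiple of 8, within the 10% guard)
print(make_divisible(1280 * 1.1))  # 1408, e.g. the head width under a 1.1x channel multiplier
print(make_divisible(19))          # nearest multiple is 16, but 16 < 0.9 * 19, so bumped to 24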
+ """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs: int = make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() # for jit.script compat + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted 
residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_connect_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_connect_rate=drop_connect_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" + + def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(EdgeResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Expansion convolution + self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) + self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x + + +class EfficientNetBuilder: + """ Build Trunk Blocks for Efficient/Mobile Networks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + 
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_connect_rate = drop_connect_rate + + # updated during build + self.in_chs = None + self.block_idx = 0 + self.block_count = 0 + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba): + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + if ba.get('num_experts', 0) > 0: + block = CondConvResidual(**ba) + else: + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = EdgeResidual(**ba) + elif bt == 'cn': + block = ConvBnAct(**ba) + else: + raise AssertionError('Uknkown block type (%s) while building model.' 
% bt) + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + return block + + def _make_stack(self, stack_args): + blocks = [] + # each stack (stage) contains a list of block arguments + for i, ba in enumerate(stack_args): + if i >= 1: + # only the first block in any stack can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + self.block_idx += 1 # incr global idx (across all stacks) + return nn.Sequential(*blocks) + + def __call__(self, in_chs, block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + self.in_chs = in_chs + self.block_count = sum([len(x) for x in block_args]) + self.block_idx = 0 + blocks = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for _stack_idx, stack in enumerate(block_args): + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + return blocks + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. 
+ Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = get_act_layer('relu') + elif v == 'r6': + value = get_act_layer('relu6') + elif v == 'hs': + value = get_act_layer('hard_swish') + elif v == 'sw': + value = get_act_layer('swish') + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + raise AssertionError('Unknown block type (%s)' % block_type) + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. 
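Worked example of the block-definition notation handled by _decode_block_str above (standalone sketch using the same regex split): 'ir_r2_k3_s2_e6_c24_se0.25' describes an inverted-residual stage with 2 repeats, a 3x3 depthwise kernel, stride 2, expansion ratio 6, 24 output channels and SE ratio 0.25:

import re

ops = 'ir_r2_k3_s2_e6_c24_se0.25'.split('_')
block_type, options = ops[0], {}
for op in ops[1:]:
    key, value = re.split(r'(\d.*)', op)[:2]   # same split used by the decoder
    options[key] = value
print(block_type, options)
# ir {'r': '2', 'k': '3', 's': '2', 'e': '6', 'c': '24', 'se': '0.25'}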
This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +def initialize_weight_goog(m, n='', fix_group_fanout=True): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def initialize_weight_default(m, n=''): + if isinstance(m, CondConv2d): + init_fn = get_condconv_initializer(partial( + nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) + init_fn(m.weight) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + 
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py new file mode 100644 index 000000000..cd170d4cc --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py @@ -0,0 +1,1450 @@ +""" Generic Efficient Networks + +A generic MobileNet class with building blocks to support a variety of models: + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* EfficientNet-Lite + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .config import layer_config_kwargs, is_scriptable +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', + 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', + 'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d', + 'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', + 'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el', + 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', + 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', + 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', + 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', + 'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap', + 'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap', + 'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', + 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', + 'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', + 'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el', + 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', + 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 
'tf_efficientnet_lite2', 'tf_efficientnet_lite3', + 'tf_efficientnet_lite4', + 'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'] + + +model_urls = { + 'mnasnet_050': None, + 'mnasnet_075': None, + 'mnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', + 'mnasnet_140': None, + 'mnasnet_small': None, + + 'semnasnet_050': None, + 'semnasnet_075': None, + 'semnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', + 'semnasnet_140': None, + + 'mobilenetv2_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', + 'mobilenetv2_110d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', + 'mobilenetv2_120d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', + 'mobilenetv2_140': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', + + 'fbnetc_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + 'spnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + + 'efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', + 'efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + 'efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + 'efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + 'efficientnet_b4': None, + 'efficientnet_b5': None, + 'efficientnet_b6': None, + 'efficientnet_b7': None, + 'efficientnet_b8': None, + 'efficientnet_l2': None, + + 'efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', + 'efficientnet_em': None, + 'efficientnet_el': None, + + 'efficientnet_cc_b0_4e': None, + 'efficientnet_cc_b0_8e': None, + 'efficientnet_cc_b1_8e': None, + + 'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', + 'efficientnet_lite1': None, + 'efficientnet_lite2': None, + 'efficientnet_lite3': None, + 'efficientnet_lite4': None, + + 'tf_efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + 'tf_efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + 'tf_efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + 'tf_efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + 'tf_efficientnet_b4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + 'tf_efficientnet_b5': + 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + 'tf_efficientnet_b6': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + 'tf_efficientnet_b7': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + 'tf_efficientnet_b8': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + + 'tf_efficientnet_b0_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + 'tf_efficientnet_b1_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + 'tf_efficientnet_b2_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + 'tf_efficientnet_b3_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + 'tf_efficientnet_b4_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + 'tf_efficientnet_b5_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + 'tf_efficientnet_b6_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + 'tf_efficientnet_b7_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + 'tf_efficientnet_b8_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + + 'tf_efficientnet_b0_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + 'tf_efficientnet_b1_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + 'tf_efficientnet_b2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + 'tf_efficientnet_b3_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + 'tf_efficientnet_b4_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + 'tf_efficientnet_b5_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + 'tf_efficientnet_b6_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + 'tf_efficientnet_b7_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + 'tf_efficientnet_l2_ns_475': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + 'tf_efficientnet_l2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + + 'tf_efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + 'tf_efficientnet_em': + 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + 'tf_efficientnet_el': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + + 'tf_efficientnet_cc_b0_4e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + 'tf_efficientnet_cc_b0_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + 'tf_efficientnet_cc_b1_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + + 'tf_efficientnet_lite0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + 'tf_efficientnet_lite1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + 'tf_efficientnet_lite2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + 'tf_efficientnet_lite3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + 'tf_efficientnet_lite4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + + 'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', + 'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', + 'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', + 'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', + + 'tf_mixnet_s': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', + 'tf_mixnet_m': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', + 'tf_mixnet_l': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', +} + + +class GenEfficientNet(nn.Module): + """ Generic EfficientNets + + An implementation of mobile optimized networks that covers: + * EfficientNet (B0-B8, L2, CondConv, EdgeTPU) + * MixNet (Small, Medium, and Large, XL) + * MNASNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + weight_init='goog'): + super(GenEfficientNet, self).__init__() + self.drop_rate = drop_rate + + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, + pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate) + self.blocks = 
nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type) + self.bn2 = norm_layer(num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(num_features, num_classes) + + for n, m in self.named_modules(): + if weight_init == 'goog': + initialize_weight_goog(m, n) + else: + initialize_weight_default(m, n) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.conv_head, self.bn2, self.act2, + self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.features(x) + x = self.global_pool(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = GenEfficientNet(**model_kwargs) + if pretrained: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. 
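Usage sketch (assumes the vendored geffnet package is importable; the public mnasnet_a1()/efficientnet_b0() constructors that wrap these private generators appear further down in this file). A model built this way maps a 224x224 image to 1000 class logits:

import torch
from geffnet.gen_efficientnet import _gen_mnasnet_a1

model = _gen_mnasnet_a1('mnasnet_a1', channel_multiplier=1.0, pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])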
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + fix_stem=fix_stem_head, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=nn.ReLU6, + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but 
is not 100% clear + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + 
act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an efficientnet-condconv model.""" + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + channel_multiplier=channel_multiplier, + act_layer=nn.ReLU6, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. 
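+
+    Note: the dot-separated kernel sizes in the arch_def below (e.g. 'k3.5.7') correspond to
+    MixNet's mixed depthwise kernel sizes, as decoded by decode_arch_def.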
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. 
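+    This is an alias for mnasnet_100 (see the return statement below).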
""" + return mnasnet_100(pretrained, **kwargs) + + +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b2', 
channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2. """ + # NOTE for train, drop_rate should be 0.5 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. """ + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. 
""" + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 AutoAug. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. 
Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. 
Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. 
Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + # NOTE for train set drop_rate=0.2 + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. 
Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py new file mode 100644 index 000000000..30c63b2ce --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py @@ -0,0 +1,71 @@ +""" Checkpoint loading / state_dict helpers +Copyright 2020 Ross Wightman +""" +import torch +import os +from collections import OrderedDict +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +def load_checkpoint(model, checkpoint_path): + if checkpoint_path and os.path.isfile(checkpoint_path): + print("=> Loading checkpoint '{}'".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + if k.startswith('module'): + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + else: + model.load_state_dict(checkpoint) + print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + else: + print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError + + +def load_pretrained(model, url, filter_fn=None, strict=True): + if not url: + print("=> Warning: Pretrained model URL is empty, using random initialization.") + return + + state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') + + input_conv = 'conv_stem' + classifier = 'classifier' + in_chans = getattr(model, input_conv).weight.shape[1] + num_classes = getattr(model, classifier).weight.shape[0] + + input_conv_weight = input_conv + '.weight' + pretrained_in_chans = state_dict[input_conv_weight].shape[1] + if in_chans != pretrained_in_chans: + if in_chans == 1: + print('=> Converting pretrained input conv {} from {} to 1 channel'.format( + input_conv_weight, pretrained_in_chans)) + conv1_weight = state_dict[input_conv_weight] + state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) + else: + print('=> Discarding pretrained input conv {} since input channel count != {}'.format( + input_conv_weight, pretrained_in_chans)) + del state_dict[input_conv_weight] + strict = False + + classifier_weight = classifier + '.weight' + pretrained_num_classes = state_dict[classifier_weight].shape[0] + if num_classes != pretrained_num_classes: + print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) + del state_dict[classifier_weight] + del state_dict[classifier + '.bias'] + strict = False + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + model.load_state_dict(state_dict, strict=strict) diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py new file mode 100644 index 000000000..b5966c28f --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py @@ -0,0 +1,364 @@ +""" MobileNet-V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. 
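+
+Example (illustrative sketch; assumes the package-level exports of this repo, e.g.
+create_model from model_factory.py, as also exercised by validate.py):
+
+    >>> import torch
+    >>> import geffnet
+    >>> model = geffnet.create_model('mobilenetv3_large_100', pretrained=False).eval()
+    >>> model(torch.randn(1, 3, 224, 224)).shape
+    torch.Size([1, 1000])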
+ +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .activations import get_act_fn, get_act_layer, HardSwish +from .config import layer_config_kwargs +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', + 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', + 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', + 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100'] + +model_urls = { + 'mobilenetv3_rw': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + 'mobilenetv3_large_075': None, + 'mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', + 'mobilenetv3_large_minimal_100': None, + 'mobilenetv3_small_075': None, + 'mobilenetv3_small_100': None, + 'mobilenetv3_small_minimal_100': None, + 'tf_mobilenetv3_large_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + 'tf_mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + 'tf_mobilenetv3_large_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + 'tf_mobilenetv3_small_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + 'tf_mobilenetv3_small_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + 'tf_mobilenetv3_small_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', +} + + +class MobileNetV3(nn.Module): + """ MobileNet-V3 + + A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the + head convolution without a final batch-norm layer before the classifier. 
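+
+    Head ordering, as implemented in features() and forward() below:
+
+        x = act1(bn1(conv_stem(x)))   # stem
+        x = blocks(x)                 # inverted-residual stages
+        x = global_pool(x)            # pool *before* the head conv
+        x = act2(conv_head(x))        # 1x1 head conv + act, no BN
+        y = classifier(x.flatten(1))  # dropout applied first when drop_rate > 0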
+ + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(MobileNetV3, self).__init__() + self.drop_rate = drop_rate + + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.classifier = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if weight_init == 'goog': + initialize_weight_goog(m) + else: + initialize_weight_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.global_pool, self.conv_head, self.act2, + nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = MobileNetV3(**model_kwargs) + if pretrained and model_urls[variant]: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model (RW variant). + + Paper: https://arxiv.org/abs/1905.02244 + + This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the + eventual Tensorflow reference impl but has a few differences: + 1. This model has no bias on the head convolution + 2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet + 3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer + from their parent block + 4. This model does not enforce divisible by 8 limitation on the SE reduction channel count + + Overall the changes are fairly minor and result in a very small parameter count difference and no + top-1/5 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
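+
+    Example (illustrative; random weights, via the mobilenetv3_rw wrapper defined below):
+        >>> model = mobilenetv3_rw(pretrained=False)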
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, # one of my mistakes + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 large/small/minimal models. + + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + 
channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, act_layer), + se_kwargs=dict( + act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet-V3 RW + Attn: See note in gen function for this variant. + """ + # NOTE for train set drop_rate=0.2 + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75""" + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large (Minimalistic) 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small (Minimalistic) 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0. 
Tensorflow compat variant.""" + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py new file mode 100644 index 000000000..4d46ea8ba --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py @@ -0,0 +1,27 @@ +from .config import set_layer_config +from .helpers import load_checkpoint + +from .gen_efficientnet import * +from .mobilenetv3 import * + + +def create_model( + model_name='mnasnet_100', + pretrained=None, + num_classes=1000, + in_chans=3, + checkpoint_path='', + **kwargs): + + model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) + + if model_name in globals(): + create_fn = globals()[model_name] + model = create_fn(**model_kwargs) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + if checkpoint_path and not pretrained: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py new file mode 100644 index 000000000..a6221b3de --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py @@ -0,0 +1 @@ +__version__ = '1.0.2' diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/hubconf.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/hubconf.py new file mode 100644 index 000000000..fd1915086 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/hubconf.py @@ -0,0 +1,81 @@ +dependencies = ['torch', 'math'] + +from geffnet import efficientnet_b0 +from geffnet import efficientnet_b1 +from geffnet import efficientnet_b2 +from geffnet import efficientnet_b3 +from geffnet import efficientnet_es +from geffnet import efficientnet_lite0 + +from geffnet import mixnet_s +from geffnet import mixnet_m +from geffnet import mixnet_l +from geffnet import mixnet_xl + +from geffnet import mobilenetv2_100 +from geffnet import mobilenetv2_110d +from geffnet import mobilenetv2_120d +from geffnet import mobilenetv2_140 + +from geffnet import mobilenetv3_large_100 +from geffnet import mobilenetv3_rw +from geffnet import mnasnet_a1 +from geffnet import mnasnet_b1 +from geffnet import fbnetc_100 +from geffnet import spnasnet_100 + +from geffnet import tf_efficientnet_b0 +from geffnet import tf_efficientnet_b1 +from geffnet import tf_efficientnet_b2 +from geffnet import tf_efficientnet_b3 +from geffnet import tf_efficientnet_b4 +from geffnet import tf_efficientnet_b5 +from geffnet import tf_efficientnet_b6 +from geffnet import tf_efficientnet_b7 +from geffnet import tf_efficientnet_b8 + +from geffnet import tf_efficientnet_b0_ap +from geffnet import tf_efficientnet_b1_ap +from geffnet import tf_efficientnet_b2_ap +from geffnet import tf_efficientnet_b3_ap +from geffnet 
import tf_efficientnet_b4_ap +from geffnet import tf_efficientnet_b5_ap +from geffnet import tf_efficientnet_b6_ap +from geffnet import tf_efficientnet_b7_ap +from geffnet import tf_efficientnet_b8_ap + +from geffnet import tf_efficientnet_b0_ns +from geffnet import tf_efficientnet_b1_ns +from geffnet import tf_efficientnet_b2_ns +from geffnet import tf_efficientnet_b3_ns +from geffnet import tf_efficientnet_b4_ns +from geffnet import tf_efficientnet_b5_ns +from geffnet import tf_efficientnet_b6_ns +from geffnet import tf_efficientnet_b7_ns +from geffnet import tf_efficientnet_l2_ns_475 +from geffnet import tf_efficientnet_l2_ns + +from geffnet import tf_efficientnet_es +from geffnet import tf_efficientnet_em +from geffnet import tf_efficientnet_el + +from geffnet import tf_efficientnet_cc_b0_4e +from geffnet import tf_efficientnet_cc_b0_8e +from geffnet import tf_efficientnet_cc_b1_8e + +from geffnet import tf_efficientnet_lite0 +from geffnet import tf_efficientnet_lite1 +from geffnet import tf_efficientnet_lite2 +from geffnet import tf_efficientnet_lite3 +from geffnet import tf_efficientnet_lite4 + +from geffnet import tf_mixnet_s +from geffnet import tf_mixnet_m +from geffnet import tf_mixnet_l + +from geffnet import tf_mobilenetv3_large_075 +from geffnet import tf_mobilenetv3_large_100 +from geffnet import tf_mobilenetv3_large_minimal_100 +from geffnet import tf_mobilenetv3_small_075 +from geffnet import tf_mobilenetv3_small_100 +from geffnet import tf_mobilenetv3_small_minimal_100 diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/requirements.txt b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/requirements.txt new file mode 100644 index 000000000..ac3ffc13b --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.2.0 +torchvision>=0.4.0 diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/setup.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/setup.py new file mode 100644 index 000000000..83388db37 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/setup.py @@ -0,0 +1,48 @@ +""" Setup +""" +from setuptools import setup, find_packages +from codecs import open +from os import path + +here = path.abspath(path.dirname(__file__)) +__version__ = '0.0.0' + +# Get the long description from the README file +with open(path.join(here, 'README.md'), encoding='utf-8') as f: + long_description = f.read() + +exec(open('geffnet/version.py').read()) +setup( + name='geffnet', + version=__version__, + description='(Generic) EfficientNets for PyTorch', + long_description=long_description, + long_description_content_type='text/markdown', + url='https://github.com/rwightman/gen-efficientnet-pytorch', + author='Ross Wightman', + author_email='hello@rwightman.com', + classifiers=[ + # How mature is this project? 
Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + + # Note that this is a string of words separated by whitespace, not a list. + keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet', + packages=find_packages(exclude=['data']), + install_requires=['torch >= 1.4', 'torchvision'], + python_requires='>=3.6', +) diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/utils.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/utils.py new file mode 100644 index 000000000..d327e8bd8 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/utils.py @@ -0,0 +1,52 @@ +import os + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + diff --git a/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/validate.py b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/validate.py new file mode 100644 index 000000000..bed37a389 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/efficientnet_repo/validate.py @@ -0,0 +1,165 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import time +import torch +import torch.nn as nn +import torch.nn.parallel +from contextlib import suppress + +import geffnet +from data import Dataset, create_loader, resolve_data_config +from utils import accuracy, AverageMeter + +has_native_amp = False +try: + if torch.cuda.amp.autocast is not None: + has_native_amp = True +except AttributeError: + pass + +torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00', + help='model architecture (default: dpn92)') +parser.add_argument('-j', 
'--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') +parser.add_argument('--num-gpu', type=int, default=1, + help='Number of GPUS to use') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--no-cuda', dest='no_cuda', action='store_true', + help='') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--amp', action='store_true', default=False, + help='Use native Torch AMP mixed precision.') + + +def main(): + args = parser.parse_args() + + if not args.checkpoint and not args.pretrained: + args.pretrained = True + + amp_autocast = suppress # do nothing + if args.amp: + if not has_native_amp: + print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.") + else: + amp_autocast = torch.cuda.amp.autocast + + # create model + model = geffnet.create_model( + args.model, + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint, + scriptable=args.torchscript) + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + if args.torchscript: + torch.jit.optimized_execution(True) + model = torch.jit.script(model) + + print('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) + + data_config = resolve_data_config(model, args) + + criterion = nn.CrossEntropyLoss() + + if not args.no_cuda: + if args.num_gpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() + else: + model = model.cuda() + criterion = criterion.cuda() + + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=not args.no_cuda, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + batch_time = AverageMeter() + losses = 
AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + model.eval() + end = time.time() + for i, (input, target) in enumerate(loader): + if not args.no_cuda: + target = target.cuda() + input = input.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + # compute output + with amp_autocast(): + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +if __name__ == '__main__': + main() diff --git a/modules/control/proc/normalbae/nets/submodules/encoder.py b/modules/control/proc/normalbae/nets/submodules/encoder.py new file mode 100644 index 000000000..21ee626a2 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/encoder.py @@ -0,0 +1,28 @@ +import os +import torch +import torch.nn as nn + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + + basemodel_name = 'tf_efficientnet_b5_ap' + repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo') + basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local') + + # Remove last layer + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if k == 'blocks': + for _ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features diff --git a/modules/control/proc/normalbae/nets/submodules/submodules.py b/modules/control/proc/normalbae/nets/submodules/submodules.py new file mode 100644 index 000000000..fdf12e133 --- /dev/null +++ b/modules/control/proc/normalbae/nets/submodules/submodules.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +######################################################################################################################## + + +# Upsample + BatchNorm +class UpSampleBN(nn.Module): + def __init__(self, skip_input, output_features): + super(UpSampleBN, self).__init__() + + self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU(), + nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +# Upsample + GroupNorm + Weight Standardization +class UpSampleGN(nn.Module): + def 
__init__(self, skip_input, output_features): + super(UpSampleGN, self).__init__() + + self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU(), + Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +# Conv2d with weight standardization +class Conv2d(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, + keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# normalize +def norm_normalize(norm_out): + min_kappa = 0.01 + norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) + norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 + kappa = F.elu(kappa) + 1.0 + min_kappa + final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) + return final_out + + +# uncertainty-guided sampling (only used during training) +def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta): + device = init_normal.device + B, _, H, W = init_normal.shape + N = int(sampling_ratio * H * W) + beta = beta + + # uncertainty map + uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W + + # gt_invalid_mask (B, H, W) + if gt_norm_mask is not None: + gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest') + gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5 + uncertainty_map[gt_invalid_mask] = -1e4 + + # (B, H*W) + _, idx = uncertainty_map.view(B, -1).sort(1, descending=True) + + # importance sampling + if int(beta * N) > 0: + importance = idx[:, :int(beta * N)] # B, beta*N + + # remaining + remaining = idx[:, int(beta * N):] # B, H*W - beta*N + + # coverage + num_coverage = N - int(beta * N) + + if num_coverage <= 0: + samples = importance + else: + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = torch.cat((importance, coverage), dim=1) # B, N + + else: + # remaining + remaining = idx[:, :] # B, H*W + + # coverage + num_coverage = N + + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = coverage + + # point coordinates + rows_int = samples // W # 0 for first row, H-1 for last row + rows_float = rows_int / float(H-1) # 0 to 1.0 + rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0 + + cols_int = samples % W # 0 for first column, W-1 for last column + cols_float = cols_int / 
float(W-1) # 0 to 1.0 + cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0 + + point_coords = torch.zeros(B, 1, N, 2) + point_coords[:, 0, :, 0] = cols_float # x coord + point_coords[:, 0, :, 1] = rows_float # y coord + point_coords = point_coords.to(device) + return point_coords, rows_int, cols_int diff --git a/modules/control/proc/openpose/LICENSE b/modules/control/proc/openpose/LICENSE new file mode 100644 index 000000000..6f60b76d3 --- /dev/null +++ b/modules/control/proc/openpose/LICENSE @@ -0,0 +1,108 @@ +OPENPOSE: MULTIPERSON KEYPOINT DETECTION +SOFTWARE LICENSE AGREEMENT +ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY + +BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. + +This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. + +RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: +Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, +non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). + +CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. + +COPYRIGHT: The Software is owned by Licensor and is protected by United +States copyright laws and applicable international treaties and/or conventions. + +PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. + +DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. 
+ +BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. + +USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. + +You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. + +ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. + +TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. + +The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. + +FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. + +DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. + +SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. + +EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. + +EXPORT REGULATION: Licensee agrees to comply with any and all applicable +U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. + +SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. + +NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. 
+ +GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. + +ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. + + + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014-2017 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014-2017, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. 
+ +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** \ No newline at end of file diff --git a/modules/control/proc/openpose/__init__.py b/modules/control/proc/openpose/__init__.py new file mode 100644 index 000000000..398b7ec40 --- /dev/null +++ b/modules/control/proc/openpose/__init__.py @@ -0,0 +1,233 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) +# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs) +# This preprocessor is licensed by CMU for non-commercial use only. + + +import os + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import json +import warnings +from typing import Callable, List, NamedTuple, Tuple, Union + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image +from . import util +from .body import Body, BodyResult, Keypoint +from .face import Face +from .hand import Hand + +HandResult = List[Keypoint] +FaceResult = List[Keypoint] + +class PoseResult(NamedTuple): + body: BodyResult + left_hand: Union[HandResult, None] + right_hand: Union[HandResult, None] + face: Union[FaceResult, None] + +def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True): + """ + Draw the detected poses on an empty canvas. + + Args: + poses (List[PoseResult]): A list of PoseResult objects containing the detected poses. + H (int): The height of the canvas. + W (int): The width of the canvas. + draw_body (bool, optional): Whether to draw body keypoints. Defaults to True. + draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True. + draw_face (bool, optional): Whether to draw face keypoints. Defaults to True. + + Returns: + numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses. + """ + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + for pose in poses: + if draw_body: + canvas = util.draw_bodypose(canvas, pose.body.keypoints) + + if draw_hand: + canvas = util.draw_handpose(canvas, pose.left_hand) + canvas = util.draw_handpose(canvas, pose.right_hand) + + if draw_face: + canvas = util.draw_facepose(canvas, pose.face) + + return canvas + + +class OpenposeDetector: + """ + A class for detecting human poses in images using the Openpose model. + + Attributes: + model_dir (str): Path to the directory where the pose models are stored. 
+ """ + def __init__(self, body_estimation, hand_estimation=None, face_estimation=None): + self.body_estimation = body_estimation + self.hand_estimation = hand_estimation + self.face_estimation = face_estimation + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=None, face_filename=None, cache_dir=None): + + if pretrained_model_or_path == "lllyasviel/ControlNet": + filename = filename or "annotator/ckpts/body_pose_model.pth" + hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth" + face_filename = face_filename or "facenet.pth" + + face_pretrained_model_or_path = "lllyasviel/Annotators" + else: + filename = filename or "body_pose_model.pth" + hand_filename = hand_filename or "hand_pose_model.pth" + face_filename = face_filename or "facenet.pth" + + face_pretrained_model_or_path = pretrained_model_or_path + + if os.path.isdir(pretrained_model_or_path): + body_model_path = os.path.join(pretrained_model_or_path, filename) + hand_model_path = os.path.join(pretrained_model_or_path, hand_filename) + face_model_path = os.path.join(face_pretrained_model_or_path, face_filename) + else: + body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir) + face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir) + + body_estimation = Body(body_model_path) + hand_estimation = Hand(hand_model_path) + face_estimation = Face(face_model_path) + + return cls(body_estimation, hand_estimation, face_estimation) + + def to(self, device): + self.body_estimation.to(device) + self.hand_estimation.to(device) + self.face_estimation.to(device) + return self + + def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]: + left_hand = None + right_hand = None + H, W, _ = oriImg.shape + for x, y, w, is_left in util.handDetect(body, oriImg): + peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + + hand_result = [ + Keypoint(x=peak[0], y=peak[1]) + for peak in peaks + ] + + if is_left: + left_hand = hand_result + else: + right_hand = hand_result + + return left_hand, right_hand + + def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]: + face = util.faceDetect(body, oriImg) + if face is None: + return None + + x, y, w = face + H, W, _ = oriImg.shape + heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :]) + peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + return [ + Keypoint(x=peak[0], y=peak[1]) + for peak in peaks + ] + + return None + + def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]: + """ + Detect poses in the given image. + Args: + oriImg (numpy.ndarray): The input image for pose detection. + include_hand (bool, optional): Whether to include hand detection. Defaults to False. + include_face (bool, optional): Whether to include face detection. Defaults to False. 
+ + Returns: + List[PoseResult]: A list of PoseResult objects containing the detected poses. + """ + oriImg = oriImg[:, :, ::-1].copy() + H, W, C = oriImg.shape + candidate, subset = self.body_estimation(oriImg) + bodies = self.body_estimation.format_body_result(candidate, subset) + + results = [] + for body in bodies: + left_hand, right_hand, face = (None,) * 3 + if include_hand: + left_hand, right_hand = self.detect_hands(body, oriImg) + if include_face: + face = self.detect_face(body, oriImg) + + results.append(PoseResult(BodyResult( + keypoints=[ + Keypoint( + x=keypoint.x / float(W), + y=keypoint.y / float(H) + ) if keypoint is not None else None + for keypoint in body.keypoints + ], + total_score=body.total_score, + total_parts=body.total_parts + ), left_hand, right_hand, face)) + + return results + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", **kwargs): + if hand_and_face is not None: + warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning) + include_hand = hand_and_face + include_face = hand_and_face + + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + H, W, C = input_image.shape + + poses = self.detect_poses(input_image, include_hand, include_face) + canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face) + + detected_map = canvas + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/openpose/body.py b/modules/control/proc/openpose/body.py new file mode 100644 index 000000000..01a23f3d8 --- /dev/null +++ b/modules/control/proc/openpose/body.py @@ -0,0 +1,254 @@ +import math +from typing import List, NamedTuple, Union +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter +from . import util +from .model import bodypose_model + + +class Keypoint(NamedTuple): + x: float + y: float + score: float = 1.0 + id: int = -1 + + +class BodyResult(NamedTuple): + # Note: Using `Union` instead of `|` operator as the ladder is a Python + # 3.10 feature. + # Annotator code should be Python 3.8 Compatible, as controlnet repo uses + # Python 3.8 environment. 
+ # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6 + keypoints: List[Union[Keypoint, None]] + total_score: float + total_parts: int + + +class Body(object): + def __init__(self, model_path): + self.model = bodypose_model() + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, oriImg): + device = next(iter(self.model.parameters())).device + # scale_search = [0.5, 1.0, 1.5, 2.0] + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale) + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + data = data.to(device) + # data = data.permute([2, 0, 1]).unsqueeze(0).float() + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1])) + + # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs + paf = util.smart_resize_k(paf, fx=stride, fy=stride) + paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1])) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(18): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) + peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in 
the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ + [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ + [55, 56], [37, 38], [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + indexA, indexB = limbSeq[k] + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ + np.linspace(candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([score_mid[int(round(startend[x][1])), int(round(startend[x][0])), 0] for x in range(len(startend))]) + vec_y = np.array([score_mid[int(round(startend[x][1])), int(round(startend[x][0])), 1] for x in range(len(startend))]) + + score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( + 0.5 * oriImg.shape[0] / norm - 1, 0) + criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append( + [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) + + connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] and j not in connection[:, 4]): + connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) + if len(connection) >= min(nA, nB): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array([item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += 
subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + return candidate, subset + + @staticmethod + def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]: + """ + Format the body results from the candidate and subset arrays into a list of BodyResult objects. + + Args: + candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id + for each body part. + subset (np.ndarray): An array of subsets containing indices to the candidate array for each + person detected. The last two columns of each row hold the total score and total parts + of the person. + + Returns: + List[BodyResult]: A list of BodyResult objects, where each object represents a person with + detected keypoints, total score, and total parts. + """ + return [ + BodyResult( + keypoints=[ + Keypoint( + x=candidate[candidate_index][0], + y=candidate[candidate_index][1], + score=candidate[candidate_index][2], + id=candidate[candidate_index][3] + ) if candidate_index != -1 else None + for candidate_index in person[:18].astype(int) + ], + total_score=person[18], + total_parts=person[19] + ) + for person in subset + ] diff --git a/modules/control/proc/openpose/face.py b/modules/control/proc/openpose/face.py new file mode 100644 index 000000000..e8e34451c --- /dev/null +++ b/modules/control/proc/openpose/face.py @@ -0,0 +1,360 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init +from torchvision.transforms import ToPILImage, ToTensor + +from . import util + + +class FaceNet(Module): + """Model the cascading heatmaps. 
""" + def __init__(self): + super(FaceNet, self).__init__() + # cnn to make feature map + self.relu = ReLU() + self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2) + self.conv1_1 = Conv2d(in_channels=3, out_channels=64, + kernel_size=3, stride=1, padding=1) + self.conv1_2 = Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1) + self.conv2_1 = Conv2d( + in_channels=64, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv2_2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv3_1 = Conv2d( + in_channels=128, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_2 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_3 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_4 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv4_1 = Conv2d( + in_channels=256, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_3 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_4 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_1 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_3_CPM = Conv2d( + in_channels=512, out_channels=128, kernel_size=3, stride=1, + padding=1) + + # stage1 + self.conv6_1_CPM = Conv2d( + in_channels=128, out_channels=512, kernel_size=1, stride=1, + padding=0) + self.conv6_2_CPM = Conv2d( + in_channels=512, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage2 + self.Mconv1_stage2 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage2 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage3 + self.Mconv1_stage3 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage3 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage4 + self.Mconv1_stage4 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage4 = Conv2d( + 
in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage4 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage5 + self.Mconv1_stage5 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage5 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage6 + self.Mconv1_stage6 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage6 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + for m in self.modules(): + if isinstance(m, Conv2d): + init.constant_(m.bias, 0) + + def forward(self, x): + """Return a list of heatmaps.""" + heatmaps = [] + + h = self.relu(self.conv1_1(x)) + h = self.relu(self.conv1_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv2_1(h)) + h = self.relu(self.conv2_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv3_1(h)) + h = self.relu(self.conv3_2(h)) + h = self.relu(self.conv3_3(h)) + h = self.relu(self.conv3_4(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv4_1(h)) + h = self.relu(self.conv4_2(h)) + h = self.relu(self.conv4_3(h)) + h = self.relu(self.conv4_4(h)) + h = self.relu(self.conv5_1(h)) + h = self.relu(self.conv5_2(h)) + h = self.relu(self.conv5_3_CPM(h)) + feature_map = h + + # stage1 + h = self.relu(self.conv6_1_CPM(h)) + h = self.conv6_2_CPM(h) + heatmaps.append(h) + + # stage2 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage2(h)) + h = self.relu(self.Mconv2_stage2(h)) + h = self.relu(self.Mconv3_stage2(h)) + h = self.relu(self.Mconv4_stage2(h)) + h = self.relu(self.Mconv5_stage2(h)) + h = self.relu(self.Mconv6_stage2(h)) + h = self.Mconv7_stage2(h) + heatmaps.append(h) + + # stage3 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage3(h)) + h = self.relu(self.Mconv2_stage3(h)) + h = self.relu(self.Mconv3_stage3(h)) + h = 
self.relu(self.Mconv4_stage3(h)) + h = self.relu(self.Mconv5_stage3(h)) + h = self.relu(self.Mconv6_stage3(h)) + h = self.Mconv7_stage3(h) + heatmaps.append(h) + + # stage4 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage4(h)) + h = self.relu(self.Mconv2_stage4(h)) + h = self.relu(self.Mconv3_stage4(h)) + h = self.relu(self.Mconv4_stage4(h)) + h = self.relu(self.Mconv5_stage4(h)) + h = self.relu(self.Mconv6_stage4(h)) + h = self.Mconv7_stage4(h) + heatmaps.append(h) + + # stage5 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage5(h)) + h = self.relu(self.Mconv2_stage5(h)) + h = self.relu(self.Mconv3_stage5(h)) + h = self.relu(self.Mconv4_stage5(h)) + h = self.relu(self.Mconv5_stage5(h)) + h = self.relu(self.Mconv6_stage5(h)) + h = self.Mconv7_stage5(h) + heatmaps.append(h) + + # stage6 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage6(h)) + h = self.relu(self.Mconv2_stage6(h)) + h = self.relu(self.Mconv3_stage6(h)) + h = self.relu(self.Mconv4_stage6(h)) + h = self.relu(self.Mconv5_stage6(h)) + h = self.relu(self.Mconv6_stage6(h)) + h = self.Mconv7_stage6(h) + heatmaps.append(h) + + return heatmaps + + +TOTEN = ToTensor() +TOPIL = ToPILImage() + + +params = { + 'gaussian_sigma': 2.5, + 'inference_img_size': 736, # 368, 736, 1312 + 'heatmap_peak_thresh': 0.1, + 'crop_scale': 1.5, + 'line_indices': [ + [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], + [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], + [13, 14], [14, 15], [15, 16], + [17, 18], [18, 19], [19, 20], [20, 21], + [22, 23], [23, 24], [24, 25], [25, 26], + [27, 28], [28, 29], [29, 30], + [31, 32], [32, 33], [33, 34], [34, 35], + [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36], + [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42], + [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], + [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], + [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], + [66, 67], [67, 60] + ], +} + + +class Face(object): + """ + The OpenPose face landmark detector model. 
+ + Args: + inference_size: set the size of the inference image size, suggested: + 368, 736, 1312, default 736 + gaussian_sigma: blur the heatmaps, default 2.5 + heatmap_peak_thresh: return landmark if over threshold, default 0.1 + + """ + def __init__(self, face_model_path, + inference_size=None, + gaussian_sigma=None, + heatmap_peak_thresh=None): + self.inference_size = inference_size or params["inference_img_size"] + self.sigma = gaussian_sigma or params['gaussian_sigma'] + self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"] + self.model = FaceNet() + self.model.load_state_dict(torch.load(face_model_path)) + self.model.eval() + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, face_img): + device = next(iter(self.model.parameters())).device + H, W, C = face_img.shape + + w_size = 384 + x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5 + + x_data = x_data.to(device) + + hs = self.model(x_data[None, ...]) + heatmaps = F.interpolate( + hs[-1], + (H, W), + mode='bilinear', align_corners=True).cpu().numpy()[0] + return heatmaps + + def compute_peaks_from_heatmaps(self, heatmaps): + all_peaks = [] + for part in range(heatmaps.shape[0]): + map_ori = heatmaps[part].copy() + binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8) + + if np.sum(binary) == 0: + continue + + positions = np.where(binary > 0.5) + intensities = map_ori[positions] + mi = np.argmax(intensities) + y, x = positions[0][mi], positions[1][mi] + all_peaks.append([x, y]) + + return np.array(all_peaks) diff --git a/modules/control/proc/openpose/hand.py b/modules/control/proc/openpose/hand.py new file mode 100644 index 000000000..78e00213e --- /dev/null +++ b/modules/control/proc/openpose/hand.py @@ -0,0 +1,89 @@ +import cv2 +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter +from skimage.measure import label + +from . 
import util +from .model import handpose_model + + +class Hand(object): + def __init__(self, model_path): + self.model = handpose_model() + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, oriImgRaw): + device = next(iter(self.model.parameters())).device + scale_search = [0.5, 1.0, 1.5, 2.0] + # scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre = 0.05 + multiplier = [x * boxsize for x in scale_search] + + wsize = 128 + heatmap_avg = np.zeros((wsize, wsize, 22)) + + Hr, Wr, Cr = oriImgRaw.shape + + oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize(oriImg, (scale, scale)) + + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + data = data.to(device) + + output = self.model(data).cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (wsize, wsize)) + + heatmap_avg += heatmap / len(multiplier) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) + + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) + max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 + label_img[label_img != max_index] = 0 + map_ori[label_img == 0] = 0 + + y, x = util.npmax(map_ori) + y = int(float(y) * float(Hr) / float(wsize)) + x = int(float(x) * float(Wr) / float(wsize)) + all_peaks.append([x, y]) + return np.array(all_peaks) + +if __name__ == "__main__": + hand_estimation = Hand('../model/hand_pose_model.pth') + + # test_image = '../images/hand.jpg' + test_image = '../images/hand.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + peaks = hand_estimation(oriImg) + canvas = util.draw_handpose(oriImg, peaks, True) + cv2.imshow('', canvas) + cv2.waitKey(0) diff --git a/modules/control/proc/openpose/model.py b/modules/control/proc/openpose/model.py new file mode 100644 index 000000000..cfa390c14 --- /dev/null +++ b/modules/control/proc/openpose/model.py @@ -0,0 +1,215 @@ +from collections import OrderedDict +import torch +import torch.nn as nn + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], + padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], + kernel_size=v[2], stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + +class bodypose_model(nn.Module): + def __init__(self): + super(bodypose_model, self).__init__() + + # these layers 
have no relu layer + no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] + blocks = {} + block0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1]) + ]) + + + # Stage 1 + block1_1 = OrderedDict([ + ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) + ]) + + block1_2 = OrderedDict([ + ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) + ]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 + +class 
handpose_model(nn.Module): + def __init__(self): + super(handpose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ + 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + # stage 1 + block1_0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), + ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), + ('conv5_3_CPM', [512, 128, 3, 1, 1]) + ]) + + block1_1 = OrderedDict([ + ('conv6_1_CPM', [128, 512, 1, 1, 0]), + ('conv6_2_CPM', [512, 22, 1, 1, 0]) + ]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + # stage 2-6 + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([ + ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + def forward(self, x): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = torch.cat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = torch.cat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = torch.cat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = torch.cat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = torch.cat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 diff --git a/modules/control/proc/openpose/util.py b/modules/control/proc/openpose/util.py new file mode 100644 index 000000000..0bba54583 --- /dev/null +++ b/modules/control/proc/openpose/util.py @@ -0,0 +1,386 @@ +from typing import List, Tuple, Union +import math +import numpy as np +import cv2 +from .body import BodyResult, Keypoint + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) 
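+
+# Note: smart_resize and smart_resize_k above preserve the channel count and pick
+# the OpenCV interpolation from the scale factor: INTER_AREA when shrinking,
+# INTER_LANCZOS4 when enlarging. A minimal usage sketch (illustrative only):
+#   canvas = np.zeros((480, 640, 3), dtype=np.uint8)
+#   resized = smart_resize(canvas, (368, 368))          # resize to a fixed (H, W)
+#   rescaled = smart_resize_k(canvas, fx=0.5, fy=0.5)   # resize by per-axis factors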
+ + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: + """ + Draw keypoints and limbs representing body pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose. + keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + H, W, _C = canvas.shape + stickwidth = 4 + + limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], + [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], + [2, 1], [1, 15], [15, 17], [1, 16], + [16, 18], + ] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for (k1_index, k2_index), color in zip(limbSeq, colors): + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None: + continue + + Y = np.array([keypoint1.x, keypoint2.x]) * float(W) + X = np.array([keypoint1.y, keypoint2.y]) * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color]) + + for keypoint, color in zip(keypoints, colors): + if keypoint is None: + continue + + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + + return canvas + + +def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + import matplotlib as mpl + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. 
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + H, W, _C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + + x1 = int(k1.x * W) + y1 = int(k1.y * H) + x2 = int(k2.x * W) + y2 = int(k2.y * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for keypoint in keypoints: + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + """ + Draw keypoints representing face pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + H, W, _C = canvas.shape + for keypoint in keypoints: + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: + """ + Detect hands in the input body pose keypoints and calculate the bounding box for each hand. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left + corner of the bounding box, the width (height) of the bounding box, and + a boolean flag indicating whether the hand is a left hand (True) or a + right hand (False). + + Notes: + - The width and height of the bounding boxes are equal since the network requires squared input. + - The minimum bounding box size is 20 pixels. 
+ """ + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + left_shoulder = keypoints[5] + left_elbow = keypoints[6] + left_wrist = keypoints[7] + right_shoulder = keypoints[2] + right_elbow = keypoints[3] + right_wrist = keypoints[4] + + # if any of three not detected + has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist)) + has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist)) + if not (has_left or has_right): + return [] + + hands = [] + #left hand + if has_left: + hands.append([ + left_shoulder.x, left_shoulder.y, + left_elbow.x, left_elbow.y, + left_wrist.x, left_wrist.y, + True + ]) + # right hand + if has_right: + hands.append([ + right_shoulder.x, right_shoulder.y, + right_elbow.x, right_elbow.y, + right_wrist.x, right_wrist.y, + False + ]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: + x = 0 + if y < 0: + y = 0 + width1 = width + width2 = width + if x + width > image_width: + width1 = image_width - x + if y + width > image_height: + width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append((int(x), int(y), int(width), is_left)) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]: + """ + Detect the face in the input body pose keypoints and calculate the bounding box for the face. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the + bounding box and the width (height) of the bounding box, or None if the + face is not detected or the bounding box width is less than 20 pixels. + + Notes: + - The width and height of the bounding box are equal. + - The minimum bounding box size is 20 pixels. 
+ """ + # left right eye ear 14 15 16 17 + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + head = keypoints[0] + left_eye = keypoints[14] + right_eye = keypoints[15] + left_ear = keypoints[16] + right_ear = keypoints[17] + + if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)): + return None + + width = 0.0 + x0, y0 = head.x, head.y + + if left_eye is not None: + x1, y1 = left_eye.x, left_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if right_eye is not None: + x1, y1 = right_eye.x, right_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if left_ear is not None: + x1, y1 = left_ear.x, left_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if right_ear is not None: + x1, y1 = right_ear.x, right_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + return int(x), int(y), int(width) + else: + return None + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/modules/control/proc/pidi.py b/modules/control/proc/pidi.py new file mode 100644 index 000000000..4c661b923 --- /dev/null +++ b/modules/control/proc/pidi.py @@ -0,0 +1,83 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, nms, resize_image, safe_step +from .pidi_model import pidinet + + +class PidiNetDetector: + def __init__(self, netNetwork): + self.netNetwork = netNetwork + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None): + filename = filename or "table5_pidinet.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + + netNetwork = pidinet() + netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()}) + netNetwork.eval() + + return cls(netNetwork) + + def to(self, device): + self.netNetwork.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, apply_filter=False, **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.netNetwork.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + assert input_image.ndim == 3 + input_image = input_image[:, :, ::-1].copy() + image_pidi = torch.from_numpy(input_image).float().to(device) + image_pidi = image_pidi / 255.0 + image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') + edge = self.netNetwork(image_pidi)[-1] + edge = edge.cpu().numpy() + if apply_filter: + edge = edge > 0.5 + if safe: + edge = safe_step(edge) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge[0, 0] + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if scribble: + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/pidi/LICENSE b/modules/control/proc/pidi/LICENSE new file mode 100644 index 000000000..913b6cf92 --- /dev/null +++ b/modules/control/proc/pidi/LICENSE @@ -0,0 +1,21 @@ +It is just for research purpose, and commercial use should be contacted with authors first. + +Copyright (c) 2021 Zhuo Su + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/modules/control/proc/pidi_model.py b/modules/control/proc/pidi_model.py new file mode 100644 index 000000000..16595b35a --- /dev/null +++ b/modules/control/proc/pidi_model.py @@ -0,0 +1,681 @@ +""" +Author: Zhuo Su, Wenzhe Liu +Date: Feb 18, 2021 +""" + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. 
+ + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + +nets = { + 'baseline': { + 'layer0': 'cv', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'c-v15': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'a-v15': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'r-v15': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cvvv4': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'avvv4': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'rvvv4': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cccv4': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cv', + }, + 'aaav4': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'cv', + }, + 'rrrv4': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 
'rd', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'cv', + }, + 'c16': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cd', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cd', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cd', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cd', + }, + 'a16': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'ad', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'ad', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'ad', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'ad', + }, + 'r16': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'rd', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'rd', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'rd', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'rd', + }, + 'carv4': { + 'layer0': 'cd', + 'layer1': 'ad', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'ad', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'ad', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'ad', + 'layer14': 'rd', + 'layer15': 'cv', + }, + } + +def createConvFunc(op_type): + assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) + if op_type == 'cv': + return F.conv2d + + if op_type == 'cd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' + assert padding == dilation, 'padding for cd_conv set wrong' + + weights_c = weights.sum(dim=[2, 3], keepdim=True) + yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) + y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y - yc + return func + elif op_type == 'ad': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' + assert padding == dilation, 'padding for ad_conv set wrong' + + shape = weights.shape + weights = weights.view(shape[0], shape[1], -1) + weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise + y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + return func + elif op_type == 'rd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' + padding = 2 * dilation + + shape = weights.shape + if weights.is_cuda: + buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) + else: + buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device) + weights = weights.view(shape[0], shape[1], -1) + buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] + buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] + buffer[:, :, 12] = 0 + buffer = buffer.view(shape[0], shape[1], 5, 5) + y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, 
dilation=dilation, groups=groups) + return y + return func + else: + print('impossible to be here unless you force that') + return None + +class Conv2d(nn.Module): + def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False): + super(Conv2d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + self.pdc = pdc + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + + return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class CSAM(nn.Module): + """ + Compact Spatial Attention Module + """ + def __init__(self, channels): + super(CSAM, self).__init__() + + mid_channels = 4 + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) + self.sigmoid = nn.Sigmoid() + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + y = self.relu1(x) + y = self.conv1(y) + y = self.conv2(y) + y = self.sigmoid(y) + + return x * y + +class CDCM(nn.Module): + """ + Compact Dilation Convolution based Module + """ + def __init__(self, in_channels, out_channels): + super(CDCM, self).__init__() + + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) + self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) + self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) + self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + x = self.relu1(x) + x = self.conv1(x) + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x3 = self.conv2_3(x) + x4 = self.conv2_4(x) + return x1 + x2 + x3 + x4 + + +class MapReduce(nn.Module): + """ + Reduce feature maps into a single edge map + """ + def __init__(self, channels): + super(MapReduce, self).__init__() + self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + return self.conv(x) + + +class PDCBlock(nn.Module): + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock, self).__init__() + self.stride=stride + + self.stride=stride + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() 
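+        # conv1 above is a depthwise pixel-difference convolution (3x3, groups=inplane);
+        # conv2 below is a 1x1 pointwise conv that mixes channels, and the shortcut
+        # (defined only when stride > 1) matches channels for the residual add in forward()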
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + +class PDCBlock_converted(nn.Module): + """ + CPDC, APDC can be converted to vanilla 3x3 convolution + RPDC can be converted to vanilla 5x5 convolution + """ + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock_converted, self).__init__() + self.stride=stride + + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + if pdc == 'rd': + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) + else: + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + +class PiDiNet(nn.Module): + def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): + super(PiDiNet, self).__init__() + self.sa = sa + if dil is not None: + assert isinstance(dil, int), 'dil should be an int' + self.dil = dil + + self.fuseplanes = [] + + self.inplane = inplane + if convert: + if pdcs[0] == 'rd': + init_kernel_size = 5 + init_padding = 2 + else: + init_kernel_size = 3 + init_padding = 1 + self.init_block = nn.Conv2d(3, self.inplane, + kernel_size=init_kernel_size, padding=init_padding, bias=False) + block_class = PDCBlock_converted + else: + self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) + block_class = PDCBlock + + self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) + self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) + self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) + self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) + self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) + self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 2C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) + self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) + self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) + self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) + self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) + self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) + self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.conv_reduces = nn.ModuleList() + if self.sa and self.dil is not None: + self.attentions = nn.ModuleList() + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.attentions.append(CSAM(self.dil)) + 
self.conv_reduces.append(MapReduce(self.dil)) + elif self.sa: + self.attentions = nn.ModuleList() + for i in range(4): + self.attentions.append(CSAM(self.fuseplanes[i])) + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + elif self.dil is not None: + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + else: + for i in range(4): + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + + self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias + nn.init.constant_(self.classifier.weight, 0.25) + nn.init.constant_(self.classifier.bias, 0) + + # print('initialization done') + + def get_weights(self): + conv_weights = [] + bn_weights = [] + relu_weights = [] + for pname, p in self.named_parameters(): + if 'bn' in pname: + bn_weights.append(p) + elif 'relu' in pname: + relu_weights.append(p) + else: + conv_weights.append(p) + + return conv_weights, bn_weights, relu_weights + + def forward(self, x): + H, W = x.size()[2:] + + x = self.init_block(x) + + x1 = self.block1_1(x) + x1 = self.block1_2(x1) + x1 = self.block1_3(x1) + + x2 = self.block2_1(x1) + x2 = self.block2_2(x2) + x2 = self.block2_3(x2) + x2 = self.block2_4(x2) + + x3 = self.block3_1(x2) + x3 = self.block3_2(x3) + x3 = self.block3_3(x3) + x3 = self.block3_4(x3) + + x4 = self.block4_1(x3) + x4 = self.block4_2(x4) + x4 = self.block4_3(x4) + x4 = self.block4_4(x4) + + x_fuses = [] + if self.sa and self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](self.dilations[i](xi))) + elif self.sa: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](xi)) + elif self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.dilations[i](xi)) + else: + x_fuses = [x1, x2, x3, x4] + + e1 = self.conv_reduces[0](x_fuses[0]) + e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) + + e2 = self.conv_reduces[1](x_fuses[1]) + e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) + + e3 = self.conv_reduces[2](x_fuses[2]) + e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) + + e4 = self.conv_reduces[3](x_fuses[3]) + e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) + + outputs = [e1, e2, e3, e4] + + output = self.classifier(torch.cat(outputs, dim=1)) + #if not self.training: + # return torch.sigmoid(output) + + outputs.append(output) + outputs = [torch.sigmoid(r) for r in outputs] + return outputs + +def config_model(model): + model_options = list(nets.keys()) + assert model in model_options, \ + 'unrecognized model, please choose from %s' % str(model_options) + + # print(str(nets[model])) + + pdcs = [] + for i in range(16): + layer_name = 'layer%d' % i + op = nets[model][layer_name] + pdcs.append(createConvFunc(op)) + + return pdcs + +def pidinet(): + pdcs = config_model('carv4') + dil = 24 #if args.dil else None + return PiDiNet(60, pdcs, dil=dil, sa=True) + + +if __name__ == '__main__': + model = pidinet() + ckp = torch.load('table5_pidinet.pth')['state_dict'] + model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) + im = cv2.imread('examples/test_my/cat_v4.png') + im = img2tensor(im).unsqueeze(0)/255. 
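+    # the model returns a list of sigmoid edge maps (four side outputs plus the
+    # fused map); [-1] below selects the fused output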
+ res = model(im)[-1] + res = res>0.5 + res = res.float() + res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8) + print(res.shape) + cv2.imwrite('edge.png', res) diff --git a/modules/control/proc/reference_sd15.py b/modules/control/proc/reference_sd15.py new file mode 100644 index 000000000..53f8602d3 --- /dev/null +++ b/modules/control/proc/reference_sd15.py @@ -0,0 +1,792 @@ +# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280 +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +import torch +from diffusers import StableDiffusionPipeline +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg +from diffusers.utils import PIL_INTERPOLATION +from diffusers.utils.torch_utils import randn_tensor + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import UniPCMultistepScheduler + >>> from diffusers.utils import load_image + + >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + + >>> pipe = StableDiffusionReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) + + >>> result_img = pipe(ref_image=input_image, + prompt="1girl", + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] + + >>> result_img.show() + ``` +""" + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +class StableDiffusionReferencePipeline(StableDiffusionPipeline): + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. 
+ while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = (image - 0.5) / 0.5 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): + refimage = refimage.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + ref_image_latents = [ + self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + ref_image_latents = torch.cat(ref_image_latents, dim=0) + else: + ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator) + ref_image_latents = self.vae.config.scaling_factor * ref_image_latents + + # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method + if ref_image_latents.shape[0] < batch_size: + if not batch_size % ref_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + return ref_image_latents + + def __call__( + self, + prompt: Union[str, List[str]] = None, + ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + attention_auto_machine_weight: float = 1.0, + gn_auto_machine_weight: float = 1.0, + style_fidelity: float = 0.5, + reference_attn: bool = True, + reference_adain: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + ref_image (`torch.FloatTensor`, `PIL.Image.Image`): + The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can + also be accepted as an image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + attention_auto_machine_weight (`float`): + Weight of using reference query for self attention's context. + If attention_auto_machine_weight=1.0, use reference query for all self attention's context. + gn_auto_machine_weight (`float`): + Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins. + style_fidelity (`float`): + style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important, + elif style_fidelity=0.0, prompt more important, else balanced. + reference_attn (`bool`): + Whether to use reference query for self attention's context. + reference_adain (`bool`): + Whether to use reference adain. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True." + + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, ref_image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess reference image + ref_image = self.prepare_image( + image=ref_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare reference latent variables + ref_image_latents = self.prepare_ref_latents( + ref_image, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. 
Modify self attention and group norm + MODE = "write" + uc_mask = ( + torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt) + .type_as(ref_image_latents) + .bool() + ) + + def hacked_basic_transformer_inner_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.detach().clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + if attention_auto_machine_weight > self.attn_weight: + attn_output_uc = self.attn1( + norm_hidden_states, + encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), + # attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output_c = attn_output_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + attn_output_c[uc_mask] = self.attn1( + norm_hidden_states[uc_mask], + encoder_hidden_states=norm_hidden_states[uc_mask], + **cross_attention_kwargs, + ) + attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc + self.bank.clear() + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. 
Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + def hacked_mid_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward(*args, **kwargs) + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + x_uc = (((x - mean) / std) * std_acc) + mean_acc + x_c = x_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + x_c[uc_mask] = x[uc_mask] + x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc + self.mean_bank = [] + self.var_bank = [] + return x + + def hack_CrossAttnDownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def 
hacked_DownBlock2D_forward(self, hidden_states, temb=None, scale=None): + eps = 1e-6 + + output_states = () + + for i, resnet in enumerate(self.resnets): + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_CrossAttnUpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + # TODO(Patrick, William) - attention mask is not used + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 
- style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale=None): + eps = 1e-6 + for i, resnet in enumerate(self.resnets): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + if reference_attn: + attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + if reference_adain: + gn_modules = [self.unet.mid_block] + self.unet.mid_block.gn_weight = 0 + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) + gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + module.gn_weight = float(w) / float(len(up_blocks)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward", None) is None: + module.original_forward = module.forward + if i == 0: + # mid_block + module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) + elif isinstance(module, CrossAttnDownBlock2D): + module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) + elif isinstance(module, DownBlock2D): + module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) + elif isinstance(module, CrossAttnUpBlock2D): + module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) + elif isinstance(module, UpBlock2D): + module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) + module.mean_bank 
= [] + module.var_bank = [] + module.gn_weight *= 2 + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # ref only part + noise = randn_tensor( + ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype + ) + ref_xt = self.scheduler.add_noise( + ref_image_latents, + noise, + t.reshape( + 1, + ), + ) + ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt + ref_xt = self.scheduler.scale_model_input(ref_xt, t) + + MODE = "write" + self.unet( + ref_xt, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + ) + + # predict the noise residual + MODE = "read" + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if output_type != "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/modules/control/proc/reference_sdxl.py b/modules/control/proc/reference_sdxl.py new file mode 100644 index 000000000..0ccd9914f --- /dev/null +++ b/modules/control/proc/reference_sdxl.py @@ -0,0 +1,804 @@ +# Based on stable_diffusion_reference.py + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers import StableDiffusionXLPipeline +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + 
DownBlock2D, + UpBlock2D, +) +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.utils import PIL_INTERPOLATION +from diffusers.utils.torch_utils import randn_tensor + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import UniPCMultistepScheduler + >>> from diffusers.utils import load_image + + >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + + >>> pipe = StableDiffusionXLReferencePipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16").to('cuda:0') + + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> result_img = pipe(ref_image=input_image, + prompt="1girl", + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] + + >>> result_img.show() + ``` +""" + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. 
+ while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 + + return height, width + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = (image - 0.5) / 0.5 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.stack(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): + refimage = refimage.to(device=device) + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + if refimage.dtype != self.vae.dtype: + refimage = refimage.to(dtype=self.vae.dtype) + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + ref_image_latents = [ + self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + ref_image_latents = torch.cat(ref_image_latents, dim=0) + else: + ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator) + ref_image_latents = self.vae.config.scaling_factor * ref_image_latents + + # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method + if ref_image_latents.shape[0] < batch_size: + if not batch_size % ref_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) + + ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents + + # aligning device to prevent device errors when concating it with the latent model input + ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + return ref_image_latents + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + attention_auto_machine_weight: float = 1.0, + gn_auto_machine_weight: float = 1.0, + style_fidelity: float = 0.5, + reference_attn: bool = True, + reference_adain: bool = True, + ): + assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True." + + # 0. Default height and width to unet + # height, width = self._default_height_width(height, width, ref_image) + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+        )
+        # 4. Preprocess reference image
+        ref_image = self.prepare_image(
+            image=ref_image,
+            width=width,
+            height=height,
+            batch_size=batch_size * num_images_per_prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            device=device,
+            dtype=prompt_embeds.dtype,
+        )
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        timesteps = self.scheduler.timesteps
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+        # 7. Prepare reference latent variables
+        ref_image_latents = self.prepare_ref_latents(
+            ref_image,
+            batch_size * num_images_per_prompt,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
+        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 9. Modify self attention and group norm
+        MODE = "write"
+        uc_mask = (
+            torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+            .type_as(ref_image_latents)
+            .bool()
+        )
+
+        def hacked_basic_transformer_inner_forward(
+            self,
+            hidden_states: torch.FloatTensor,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            encoder_hidden_states: Optional[torch.FloatTensor] = None,
+            encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            timestep: Optional[torch.LongTensor] = None,
+            cross_attention_kwargs: Dict[str, Any] = None,
+            class_labels: Optional[torch.LongTensor] = None,
+        ):
+            if self.use_ada_layer_norm:
+                norm_hidden_states = self.norm1(hidden_states, timestep)
+            elif self.use_ada_layer_norm_zero:
+                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+                )
+            else:
+                norm_hidden_states = self.norm1(hidden_states)
+
+            # 1. 
Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.detach().clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + if attention_auto_machine_weight > self.attn_weight: + attn_output_uc = self.attn1( + norm_hidden_states, + encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), + # attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output_c = attn_output_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + attn_output_c[uc_mask] = self.attn1( + norm_hidden_states[uc_mask], + encoder_hidden_states=norm_hidden_states[uc_mask], + **cross_attention_kwargs, + ) + attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc + self.bank.clear() + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. 
Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + def hacked_mid_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward(*args, **kwargs) + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + x_uc = (((x - mean) / std) * std_acc) + mean_acc + x_c = x_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + x_c[uc_mask] = x[uc_mask] + x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc + self.mean_bank = [] + self.var_bank = [] + return x + + def hack_CrossAttnDownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def 
hacked_DownBlock2D_forward(self, hidden_states, temb=None): + eps = 1e-6 + + output_states = () + + for i, resnet in enumerate(self.resnets): + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_CrossAttnUpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + # TODO(Patrick, William) - attention mask is not used + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - 
style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + eps = 1e-6 + for i, resnet in enumerate(self.resnets): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + if reference_attn: + attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + if reference_adain: + gn_modules = [self.unet.mid_block] + self.unet.mid_block.gn_weight = 0 + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) + gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + module.gn_weight = float(w) / float(len(up_blocks)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward", None) is None: + module.original_forward = module.forward + if i == 0: + # mid_block + module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) + elif isinstance(module, CrossAttnDownBlock2D): + module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) + elif isinstance(module, DownBlock2D): + module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) + elif isinstance(module, CrossAttnUpBlock2D): + module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) + elif isinstance(module, UpBlock2D): + module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) + module.mean_bank = [] + 
module.var_bank = [] + module.gn_weight *= 2 + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 10.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # ref only part + noise = randn_tensor( + ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype + ) + ref_xt = self.scheduler.add_noise( + ref_image_latents, + noise, + t.reshape( + 1, + ), + ) + ref_xt = self.scheduler.scale_model_input(ref_xt, t) + + MODE = "write" + + self.unet( + ref_xt, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + # predict the noise residual + MODE = "read" + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if output_type != "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/modules/control/proc/segment_anything/__init__.py b/modules/control/proc/segment_anything/__init__.py new file mode 100644 index 000000000..bcd7195c7 --- /dev/null +++ b/modules/control/proc/segment_anything/__init__.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
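+# SamDetector (below) wraps Segment Anything (SAM) automatic mask generation for use as a
+# control preprocessor: it downloads a checkpoint via huggingface_hub, runs
+# SamAutomaticMaskGenerator on the input image and renders the resulting masks as a
+# color-coded segmentation map.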
+ +import os +import warnings +from typing import Union + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from modules.control.util import HWC3, resize_image +from .automatic_mask_generator import SamAutomaticMaskGenerator +from .build_sam import sam_model_registry + + +class SamDetector: + def __init__(self, mask_generator: SamAutomaticMaskGenerator = None): + self.mask_generator = mask_generator + + @classmethod + def from_pretrained(cls, model_path, filename, model_type, cache_dir=None): + """ + Possible model_type : vit_h, vit_l, vit_b, vit_t + download weights from https://github.com/facebookresearch/segment-anything + """ + model_path = hf_hub_download(model_path, filename, cache_dir=cache_dir) + + sam = sam_model_registry[model_type](checkpoint=model_path) + + if torch.cuda.is_available(): + sam.to("cuda") + + mask_generator = SamAutomaticMaskGenerator(sam) + + return cls(mask_generator) + + + def show_anns(self, anns): + from numpy.random import default_rng + gen = default_rng() + if len(anns) == 0: + return + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + h, w = anns[0]['segmentation'].shape + final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") + for ann in sorted_anns: + m = ann['segmentation'] + img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) + for i in range(3): + img[:,:,i] = gen.integers(255, dtype=np.uint8) + final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m*255))) + + return np.array(final_img, dtype=np.uint8) + + def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> Image.Image: + if "image" in kwargs: + warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("image") + + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + # Generate Masks + masks = self.mask_generator.generate(input_image) + # Create map + image_map = self.show_anns(masks) + + detected_map = image_map + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/segment_anything/automatic_mask_generator.py b/modules/control/proc/segment_anything/automatic_mask_generator.py new file mode 100644 index 000000000..a5029053e --- /dev/null +++ b/modules/control/proc/segment_anything/automatic_mask_generator.py @@ -0,0 +1,371 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
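+# SamAutomaticMaskGenerator (below) prompts SAM with a regular grid of points, optionally over
+# image crops, then filters low-quality and duplicate masks; parameters are documented in the
+# class docstring.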
+ +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 
+ For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore + + if min_mask_region_area > 0: + import cv2 # type: ignore + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. 
+ """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. 
+ keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. 
+ """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/modules/control/proc/segment_anything/build_sam.py b/modules/control/proc/segment_anything/build_sam.py new file mode 100644 index 000000000..9a52c506b --- /dev/null +++ b/modules/control/proc/segment_anything/build_sam.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 
57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + mobile_sam.load_state_dict(state_dict) + return mobile_sam + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, + "vit_t": build_sam_vit_t, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam + + diff --git a/modules/control/proc/segment_anything/modeling/__init__.py b/modules/control/proc/segment_anything/modeling/__init__.py new file mode 100644 index 000000000..7aa261b83 --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer +from .tiny_vit_sam import TinyViT diff --git a/modules/control/proc/segment_anything/modeling/common.py b/modules/control/proc/segment_anything/modeling/common.py new file mode 100644 index 000000000..dc410acda --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
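+# Shared building blocks reused across the SAM modeling components:
+# a simple two-layer MLP and a channels-first LayerNorm variant.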
+ +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/modules/control/proc/segment_anything/modeling/image_encoder.py b/modules/control/proc/segment_anything/modeling/image_encoder.py new file mode 100644 index 000000000..689972ea8 --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. 
+ """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/modules/control/proc/segment_anything/modeling/mask_decoder.py b/modules/control/proc/segment_anything/modeling/mask_decoder.py new file mode 100644 index 000000000..27e577120 --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/modules/control/proc/segment_anything/modeling/prompt_encoder.py b/modules/control/proc/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 000000000..c3143f4f8 --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/modules/control/proc/segment_anything/modeling/sam.py b/modules/control/proc/segment_anything/modeling/sam.py new file mode 100644 index 000000000..614fd7483 --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/sam.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
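+# End-to-end SAM module: ties together the image encoder, prompt encoder and mask decoder,
+# and handles input preprocessing (normalize + pad) and mask postprocessing.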
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple, Union + +from .tiny_vit_sam import TinyViT +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: Union[ImageEncoderViT, TinyViT], + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = None, + pixel_std: List[float] = None, + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + if pixel_std is None: + pixel_std = [58.395, 57.12, 57.375] + if pixel_mean is None: + pixel_mean = [123.675, 116.28, 103.53] + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. 
+ """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. 
+ """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/modules/control/proc/segment_anything/modeling/tiny_vit_sam.py b/modules/control/proc/segment_anything/modeling/tiny_vit_sam.py new file mode 100644 index 000000000..165bd11bb --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/tiny_vit_sam.py @@ -0,0 +1,721 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath as TimmDropPath,\ + to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from typing import Tuple + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f'(drop_prob={self.drop_prob})' + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, + activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, 
ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, + ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN( + self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c=2 + if(out_dim==320 or out_dim==448 or out_dim==576): + stride_c=1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__(self, dim, input_resolution, depth, + activation, + drop_path=0., downsample=None, use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv(dim, dim, conv_expand_ratio, activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product( + 
range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, + mlp_ratio=4., drop=0., drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, + attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % + self.window_size) % self.window_size + pad_r = (self.window_size - W % + self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. 
Default: dim + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., + drop_path=0., downsample=None, use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x +class TinyViT(nn.Module): + def __init__(self, img_size=224, in_chans=3, num_classes=1000, + embed_dims=None, depths=None, + num_heads=None, + window_sizes=None, + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + if window_sizes is None: + window_sizes = [7, 7, 14, 7] + if num_heads is None: + num_heads = [3, 6, 12, 24] + if depths is None: + depths = [2, 2, 6, 2] + if embed_dims is None: + embed_dims = [96, 192, 384, 768] + super().__init__() + self.img_size=img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict(dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min( + i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer 
== 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + #print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) # noqa + i += 1 + if layer.downsample is not None: + layer.downsample.apply( + lambda x: _set_lr_scale(x, lr_scales[i - 1])) # noqa + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + B,_,C=x.size() + x = x.view(B, 64, 64, C) + x=x.permute(0, 3, 1, 2) + x=self.neck(x) + return x + + def forward(self, x): + x = self.forward_features(x) + #x = self.norm_head(x) + #x = self.head(x) + return x + + +_checkpoint_url_format = \ + 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' +_provided_checkpoints = { + 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', + 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', + 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', + 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', + 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', +} + + +def register_tiny_vit_model(fn): + '''Register a TinyViT model + It is a wrapper of `register_model` with loading the pretrained checkpoint. + ''' + def fn_wrapper(pretrained=False, **kwargs): + model = fn() + if pretrained: + model_name = fn.__name__ + assert model_name in _provided_checkpoints, \ + f'Sorry that the checkpoint `{model_name}` is not provided yet.' 
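+            # Resolve the distilled-checkpoint id for this variant, download it via
+            # torch.hub (cached locally), and load the weights into the model.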
+ url = _checkpoint_url_format.format( + _provided_checkpoints[model_name]) + checkpoint = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', check_hash=False, + ) + model.load_state_dict(checkpoint['model']) + + return model + + # rename the name of fn_wrapper + fn_wrapper.__name__ = fn.__name__ + return register_model(fn_wrapper) + + +@register_tiny_vit_model +def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): + return TinyViT( + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=384, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=512, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=drop_path_rate, + ) diff --git a/modules/control/proc/segment_anything/modeling/transformer.py b/modules/control/proc/segment_anything/modeling/transformer.py new file mode 100644 index 000000000..28fafea52 --- /dev/null +++ b/modules/control/proc/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. 
Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
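+        Queries (sparse prompt tokens) and keys (dense image tokens) are both updated, so information flows in both directions within every block.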
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
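+        # downsample_rate shrinks the per-token working width: e.g. embedding_dim=256 with downsample_rate=2 projects q/k/v down to internal_dim=128.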
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/modules/control/proc/segment_anything/predictor.py b/modules/control/proc/segment_anything/predictor.py new file mode 100644 index 000000000..742a34ef1 --- /dev/null +++ b/modules/control/proc/segment_anything/predictor.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. 
Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." 
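+            # Prompts are mapped into the model's resized input frame and given a leading batch dimension before being passed to predict_torch.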
+ point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/modules/control/proc/segment_anything/utils/__init__.py b/modules/control/proc/segment_anything/utils/__init__.py new file mode 100644 index 000000000..5277f4615 --- /dev/null +++ b/modules/control/proc/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/modules/control/proc/segment_anything/utils/amg.py b/modules/control/proc/segment_anything/utils/amg.py new file mode 100644 index 000000000..be064071e --- /dev/null +++ b/modules/control/proc/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." 
+ self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. 
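+    For example, a 2x2 mask [[0, 1], [1, 1]] flattens column-major to [0, 1, 1, 1] and encodes as {"size": [2, 2], "counts": [1, 3]}.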
+ """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. 
+ """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
+ """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/modules/control/proc/segment_anything/utils/onnx.py b/modules/control/proc/segment_anything/utils/onnx.py new file mode 100644 index 000000000..103867faf --- /dev/null +++ b/modules/control/proc/segment_anything/utils/onnx.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. 
+ """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. 
+ score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/modules/control/proc/segment_anything/utils/transforms.py b/modules/control/proc/segment_anything/utils/transforms.py new file mode 100644 index 000000000..c08ba1e3d --- /dev/null +++ b/modules/control/proc/segment_anything/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. 
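+        For example, with original_size=(600, 800) and target_length=1024, the point (400, 300) maps to (512, 384).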
+ """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/modules/control/proc/shuffle.py b/modules/control/proc/shuffle.py new file mode 100644 index 000000000..3ee285857 --- /dev/null +++ b/modules/control/proc/shuffle.py @@ -0,0 +1,100 @@ +import warnings +import random +import cv2 +import numpy as np +from PIL import Image + +from modules.control.util import HWC3, img2mask, make_noise_disk, resize_image + + +class ContentShuffleDetector: + def __call__(self, input_image, h=None, w=None, f=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + H, W, _C = input_image.shape + if h is None: + h = H + if w is None: + w = W + if f is None: + f = 256 + x = make_noise_disk(h, w, 1, f) * float(W - 1) + y = make_noise_disk(h, w, 1, f) * float(H - 1) + flow = np.concatenate([x, y], axis=2).astype(np.float32) + detected_map = cv2.remap(input_image, flow, None, cv2.INTER_LINEAR) + + img = resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + + +class ColorShuffleDetector: + def __call__(self, img): + H, W, C = img.shape + F = np.random.randint(64, 384) # noqa + A = make_noise_disk(H, W, 3, F) + B = make_noise_disk(H, W, 3, F) + C = (A + B) / 2.0 + A = (C + (A - C) * 3.0).clip(0, 1) + B = (C + (B - C) * 3.0).clip(0, 1) + L = img.astype(np.float32) / 255.0 + Y = A * L + B * (1 - L) + Y -= np.min(Y, axis=(0, 1), keepdims=True) + Y /= np.maximum(np.max(Y, axis=(0, 1), keepdims=True), 1e-5) + Y *= 255.0 + return Y.clip(0, 255).astype(np.uint8) + + +class GrayDetector: + def __call__(self, img): + eps = 1e-5 + X = img.astype(np.float32) + r, g, b = X[:, :, 0], X[:, :, 1], X[:, :, 2] + kr, kg, kb = [random.random() + eps for _ in range(3)] + ks = kr + kg + kb + kr /= ks + kg /= ks + kb /= ks + Y = r * kr + g * kg + b * kb + Y = np.stack([Y] * 3, axis=2) + return Y.clip(0, 255).astype(np.uint8) + + +class DownSampleDetector: + def __call__(self, img, level=3, k=16.0): + h = img.astype(np.float32) + for _ in range(level): + h += np.random.normal(loc=0.0, scale=k, size=h.shape) # noqa + h = cv2.pyrDown(h) + for _ in range(level): + h = cv2.pyrUp(h) + h += np.random.normal(loc=0.0, scale=k, size=h.shape) # noqa + return h.clip(0, 255).astype(np.uint8) + + +class Image2MaskShuffleDetector: + def __init__(self, resolution=(640, 512)): + self.H, self.W = resolution + + def __call__(self, img): + m = img2mask(img, self.H, self.W) + m *= 255.0 + return m.clip(0, 255).astype(np.uint8) diff --git a/modules/control/proc/zoe/LICENSE b/modules/control/proc/zoe/LICENSE new file mode 100644 index 000000000..7a1e90d00 --- /dev/null +++ b/modules/control/proc/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/modules/control/proc/zoe/__init__.py b/modules/control/proc/zoe/__init__.py new file mode 100644 index 000000000..e140496a4 --- /dev/null +++ b/modules/control/proc/zoe/__init__.py @@ -0,0 +1,97 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image +import safetensors + +from modules.control.util import HWC3, resize_image +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.models.zoedepth_nk.zoedepth_nk_v1 import ZoeDepthNK +from .zoedepth.utils.config import get_config + + +class ZoeDetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, model_type="zoedepth", filename=None, cache_dir=None): + filename = filename or "ZoeD_M12_N.pt" + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir) + + if model_type == "zoedepth": + model_cls = ZoeDepth + elif model_type == "zoedepth_nk": + model_cls = ZoeDepthNK + else: + raise ValueError(f"ZoeDepth unknown model type {model_type}") + conf = get_config(model_type, "infer") + model = model_cls.build_from_config(conf) + # model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model']) + if model_path.lower().endswith('.safetensors'): + model_dict = safetensors.torch.load_file(model_path, device='cpu') + else: + model_dict = torch.load(model_path, map_location=torch.device('cpu')) + if hasattr(model_dict, 'model'): + model_dict = model_dict['model'] + model.load_state_dict(model_dict, strict=False) + # timm compatibility issue + for b in model.core.core.pretrained.model.blocks: + b.drop_path = torch.nn.Identity() + model.eval() + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type=None, gamma_corrected=False): + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + image_depth = input_image + image_depth = torch.from_numpy(image_depth).float().to(device) + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + + if gamma_corrected: + depth = np.power(depth, 2.2) + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = depth_image + detected_map = HWC3(detected_map) + + img = 
resize_image(input_image, image_resolution) + H, W, _C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/modules/control/proc/zoe/zoedepth/__init__.py b/modules/control/proc/zoe/zoedepth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/zoe/zoedepth/models/__init__.py b/modules/control/proc/zoe/zoedepth/models/__init__.py new file mode 100644 index 000000000..5f2668792 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/__init__.py b/modules/control/proc/zoe/zoedepth/models/base_models/__init__.py new file mode 100644 index 000000000..5f2668792 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas.py new file mode 100644 index 000000000..683bd0329 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas.py @@ -0,0 +1,378 @@ +# MIT License +import os + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
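+        For example, with width=height=384, keep_aspect_ratio=True, ensure_multiple_of=32 and resize_method="minimal", a 500x375 (WxH) input resizes to 512x384.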
+ """ + # print("Params passed to Resize transform:") + # print("\twidth: ", width) + # print("\theight: ", height) + # print("\tresize_target: ", resize_target) + # print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + # print("\tensure_multiple_of: ", ensure_multiple_of) + # print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (int(width), int(height)), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + self.normalization = Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class MidasCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, **kwargs): + """Midas Base model used for multi-scale 
feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + if denorm: + x = denormalize(x) + x = self.prep(x) + # print("Shape after prep: ", x.shape) + + with torch.set_grad_enabled(self.trainable): + + # print("Input size to Midascore", x.shape) + rel_depth = self.core(x) + # print("Output from midas shape", rel_depth.shape) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.scratch.output_conv.children())[ + 3].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + self.handles.append(midas.scratch.refinenet3.register_forward_hook( + get_activation("r3", 
self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self, model_type): + self.output_channels = MIDAS_SETTINGS[model_type] + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if midas_model_type not in MIDAS_SETTINGS: + raise ValueError( + f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}") + if "img_size" in kwargs: + kwargs = MidasCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + # print("img_size", img_size) + midas_path = os.path.join(os.path.dirname(__file__), 'midas_repo') + midas = torch.hub.load(midas_path, midas_model_type, + pretrained=use_pretrained_midas, force_reload=force_reload, source='local') + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + midas_core.set_output_channels(midas_model_type) + return midas_core + + @staticmethod + def build_from_config(config): + return MidasCore.build(**config) + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/LICENSE b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/LICENSE new file mode 100644 index 000000000..277b5c11b --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to 
the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/README.md b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/README.md new file mode 100644 index 000000000..9568ea71c --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/README.md @@ -0,0 +1,259 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + + +and our [preprint](https://arxiv.org/abs/2103.13413): + +> Vision Transformers for Dense Prediction +> René Ranftl, Alexey Bochkovskiy, Vladlen Koltun + + +MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with +multi-objective optimization. +The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2). +The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters. 
+ +![](figures/Improvement_vs_FPS.png) + +### Setup + +1) Pick one or more models and download the corresponding weights to the `weights` folder: + +MiDaS 3.1 +- For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) +- For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt) +- For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt) +- For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin) + +MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) + +MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) + +1) Set up dependencies: + + ```shell + conda env create -f environment.yaml + conda activate midas-py310 + ``` + +#### optional + +For the Next-ViT model, execute + +```shell +git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit +``` + +For the OpenVINO model, install + +```shell +pip install openvino +``` + +### Usage + +1) Place one or more input images in the folder `input`. + +2) Run the model with + + ```shell + python run.py --model_type --input_path input --output_path output + ``` + where `````` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type), + [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type), + [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type), + [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type), + [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type). + +3) The resulting depth maps are written to the `output` folder. + +#### optional + +1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This + size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single + inference height but a range of different heights. Feel free to explore different heights by appending the extra + command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may + decrease the model accuracy. +2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is + supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution, + disregarding the aspect ratio while preserving the height, use the command line argument `--square`. 
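+
+The resizing described above comes down to the arithmetic implemented by the `Resize` transforms bundled in this patch: scale the input towards the network resolution and snap both sides to a multiple of 32. The snippet below is a minimal, standalone sketch of that computation, assuming the aspect-ratio-preserving `"minimal"` mode that `PrepForMidas` defaults to elsewhere in this patch; the helper names are illustrative and not part of MiDaS itself.
+
+```python
+import numpy as np
+
+def constrain_to_multiple_of(x, multiple_of=32):
+    # Snap a side length to the nearest multiple of 32 (illustrative helper).
+    return int(np.round(x / multiple_of) * multiple_of)
+
+def network_input_size(width, height, net_size=384, multiple_of=32):
+    # "minimal" mode with keep_aspect_ratio=True: pick the scale that changes
+    # the image the least, apply it to both sides, then snap to a multiple of 32.
+    scale_w = net_size / width
+    scale_h = net_size / height
+    scale = scale_w if abs(1 - scale_w) < abs(1 - scale_h) else scale_h
+    return (constrain_to_multiple_of(scale * width, multiple_of),
+            constrain_to_multiple_of(scale * height, multiple_of))
+
+print(network_input_size(1920, 1080))  # a 1920x1080 frame is fed as 672x384
+```
+
+In the `"lower_bound"` and `"upper_bound"` modes the same rounding is additionally clamped so that the result does not fall below, respectively exceed, the network resolution.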
+ +#### via Camera + + If you want the input images to be grabbed from the camera and shown in a window, leave the input and output paths + away and choose a model type as shown above: + + ```shell + python run.py --model_type --side + ``` + + The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown + side-by-side for comparison. + +#### via Docker + +1) Make sure you have installed Docker and the + [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)). + +2) Build the Docker image: + + ```shell + docker build -t midas . + ``` + +3) Run inference: + + ```shell + docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas + ``` + + This command passes through all of your NVIDIA GPUs to the container, mounts the + `input` and `output` directories and then runs the inference. + +#### via PyTorch Hub + +The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/) + +#### via TensorFlow or ONNX + +See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory. + +Currently only supports MiDaS v2.1. + + +#### via Mobile (iOS / Android) + +See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory. + +#### via ROS1 (Robot Operating System) + +See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory. + +Currently only supports MiDaS v2.1. DPT-based models to be added. + + +### Accuracy + +We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets +(see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**. +$\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to +MiDaS 3.0 DPTL-384. The models are grouped by the height used for inference, whereas the square training resolution is given by +the numbers in the model names. The table also shows the **number of parameters** (in millions) and the +**frames per second** for inference at the training resolution (for GPU RTX 3090): + +| MiDaS Model | DIW
WHDR | Eth3d<br>AbsRel | Sintel<br>AbsRel | TUM<br>δ1 | KITTI<br>δ1 | NYUv2<br>δ1 | $\color{green}{\textsf{Imp.}}$<br>% | Par.<br>M | FPS
  | +|-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:| +| **Inference height 512** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** | +| | | | | | | | | | | +| **Inference height 384** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 | +| [v3.1 Swin2L-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 | +| [v3.1 Swin2B-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 | +| [v3.1 SwinL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 | +| [v3.1 BEiTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 | +| [v3.1 Next-ViTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 | +| [v3.1 BEiTB-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 | +| [v3.0 DPTL-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** | +| [v3.0 DPTH-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 | +| [v2.1 Large384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 | +| | | | | | | | | | | +| **Inference height 256** | | | | | | | | | | +| [v3.1 Swin2T-256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 | +| [v2.1 Small256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** | +| | | | | | | | | | | +| **Inference height 224** | | | | | | | | | | +| [v3.1 
LeViT224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** | + +* No zero-shot error, because models are also trained on KITTI and NYU Depth V2\ +$\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model +does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other +validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the +improvement, because these quantities are averages over the pixels of an image and do not take into account the +advantage of more details due to a higher resolution.\ +Best values per column and same validation height in bold + +#### Improvement + +The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0 +DPTL-384 and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then +the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%. + +Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the +improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large384 +and v2.0 Large384 respectively instead of v3.0 DPTL-384. + +### Depth map comparison + +Zoom in for better visibility +![](figures/Comparison.png) + +### Speed on Camera Feed + +Test configuration +- Windows 10 +- 11th Gen Intel Core i7-1185G7 3.00GHz +- 16GB RAM +- Camera resolution 640x480 +- openvino_midas_v21_small_256 + +Speed: 22 FPS + +### Changelog + +* [Dec 2022] Released MiDaS v3.1: + - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf)) + - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split + - Best model, BEiTLarge 512, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0 + - Integrated live depth estimation from camera feed +* [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large). +* [Apr 2021] Released MiDaS v3.0: + - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1 + - Additional models can be found [here](https://github.com/isl-org/DPT) +* [Nov 2020] Released MiDaS v2.1: + - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2) + - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms. 
+ - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android) + - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots +* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/). +* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust +* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1)) + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@ARTICLE {Ranftl2022, + author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun", + title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer", + journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", + year = "2022", + volume = "44", + number = "3" +} +``` + +If you use a DPT-based model, please also cite: + +``` +@article{Ranftl2021, + author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, + title = {Vision Transformers for Dense Prediction}, + journal = {ICCV}, + year = {2021}, +} +``` + +### Acknowledgements + +Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT). +We'd like to thank the authors for making these libraries available. + +### License + +MIT License diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/__init__.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/hubconf.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/hubconf.py new file mode 100644 index 000000000..43291563d --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/hubconf.py @@ -0,0 +1,435 @@ +dependencies = ["torch"] + +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small + +def DPT_BEiT_L_512(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_512 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_512", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_B_384(pretrained=True, **kwargs): + """ 
# This docstring shows up in hub.help() + MiDaS DPT_BEiT_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitb16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2l24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2b24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_T_256(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_T_256 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2t16_256", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Swin_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Swin_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swinl12_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Next_ViT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="next_vit_large_6m", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + 
checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_LeViT_224(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_LeViT_224 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Large(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Large model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Hybrid(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Hybrid model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitb_rn50_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet() + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS_small(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True}) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + + +def transforms(): + import cv2 + from torchvision.transforms import Compose + from midas.transforms import Resize, NormalizeImage, PrepareForNet + from midas import transforms + + transforms.default_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + 
ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.small_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.dpt_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.beit512_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 512, + 512, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin384_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin256_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.levit_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 224, + 224, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + return transforms diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/__init__.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/__init__.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py new file mode 100644 index 000000000..ab7458704 --- /dev/null +++ 
b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py @@ -0,0 +1,202 @@ +import timm +import torch +import types + +import numpy as np +import torch.nn.functional as F + +from .utils import forward_adapted_unflatten, make_backbone_default +from timm.models.beit import gen_relative_position_index +from torch.utils.checkpoint import checkpoint +from typing import Optional + + +def forward_beit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_features") + + +def patch_embed_forward(self, x): + """ + Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes. + """ + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +def _get_rel_pos_bias(self, window_size): + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = F.interpolate(old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear") + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) + + key = str(window_size[1]) + "," + str(window_size[0]) + if key not in self.relative_position_indices.keys(): + self.relative_position_indices[key] = gen_relative_position_index(window_size) + + relative_position_bias = new_relative_position_bias_table[ + self.relative_position_indices[key].view(-1)].view( + window_size[0] * window_size[1] + 1, + window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + +def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. + """ + B, N, C = x.shape + + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + window_size = tuple(np.array(resolution) // 16) + attn = attn + self._get_rel_pos_bias(window_size) + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. 
+ """ + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution, + shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +def beit_forward_features(self, x): + """ + Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. + """ + resolution = x.shape[2:] + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) + x = self.norm(x) + return x + + +def _make_beit_backbone( + model, + features=None, + size=None, + hooks=None, + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + if hooks is None: + hooks = [0, 4, 8, 11] + if size is None: + size = [384, 384] + if features is None: + features = [96, 192, 384, 768] + backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) + backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) + + for block in backbone.model.blocks: + attn = block.attn + attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) + attn.forward = types.MethodType(attention_forward, attn) + attn.relative_position_indices = {} + + block.forward = types.MethodType(block_forward, block) + + return backbone + + +def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + + features = [256, 512, 1024, 1024] + + return _make_beit_backbone( + model, + features=features, + size=[512, 512], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + ) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py new file mode 100644 index 000000000..84287762c --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py @@ -0,0 +1,109 @@ +import timm +import torch +import torch.nn as nn +import numpy as np + 
+from .utils import activations, get_activation, Transpose + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=None, + patch_grid=None +): + if patch_grid is None: + patch_grid = [14, 14] + if hooks is None: + hooks = [3, 11, 21] + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. + """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. 
+ """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks is None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py new file mode 100644 index 000000000..a55c0a224 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py @@ -0,0 +1,37 @@ +import timm +import torch.nn as nn +from .utils import activations, forward_default, get_activation +from ..external.next_vit.classification.nextvit import * # noqa + + +def forward_next_vit(pretrained, x): + return forward_default(pretrained, x, "forward") + + +def _make_next_vit_backbone( + model, + hooks=None, +): + if hooks is None: + hooks = [2, 6, 36, 39] + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + return pretrained + + +def _make_pretrained_next_vit_large_6m(hooks=None): + model = timm.create_model("nextvit_large") + + hooks = [2, 6, 36, 39] if hooks is None else hooks + return _make_next_vit_backbone( + model, + hooks=hooks, + ) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py new file mode 100644 index 000000000..66850a0fb --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py @@ -0,0 +1,13 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swinl12_384(pretrained, hooks=None): + model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks is None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py new file mode 100644 index 000000000..ee917d836 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py @@ -0,0 +1,34 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swin2l24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks is None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2b24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks is None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2t16_256(pretrained, hooks=None): + model = 
timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) + + hooks = [1, 1, 5, 1] if hooks is None else hooks + return _make_swin_backbone( + model, + hooks=hooks, + patch_grid=[64, 64] + ) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py new file mode 100644 index 000000000..2f0c1225a --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py @@ -0,0 +1,56 @@ +import torch + +import torch.nn as nn +import numpy as np + +from .utils import activations, forward_default, get_activation, Transpose + + +def forward_swin(pretrained, x): + return forward_default(pretrained, x) + + +def _make_swin_backbone( + model, + hooks=None, + patch_grid=None +): + if patch_grid is None: + patch_grid = [96, 96] + if hooks is None: + hooks = [1, 1, 17, 1] + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if hasattr(model, "patch_grid"): + used_patch_grid = model.patch_grid + else: + used_patch_grid = patch_grid + + patch_grid_size = np.array(used_patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) + ) + pretrained.act_postprocess4 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) + ) + + return pretrained diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py new file mode 100644 index 000000000..bed17f97d --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py @@ -0,0 +1,253 @@ +import torch + +import torch.nn as nn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, 
dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def forward_default(pretrained, x, function_name="forward_features"): + exec(f"pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + if hasattr(pretrained, "act_postprocess1"): + layer_1 = pretrained.act_postprocess1(layer_1) + if hasattr(pretrained, "act_postprocess2"): + layer_2 = pretrained.act_postprocess2(layer_2) + if hasattr(pretrained, "act_postprocess3"): + layer_3 = pretrained.act_postprocess3(layer_3) + if hasattr(pretrained, "act_postprocess4"): + layer_4 = pretrained.act_postprocess4(layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def forward_adapted_unflatten(pretrained, x, function_name="forward_features"): + b, c, h, w = x.shape + + exec(f"glob = pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + raise AssertionError("wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'") + + return readout_oper + + +def make_backbone_default( + model, + features=None, + size=None, + hooks=None, + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + if hooks is None: + hooks = [2, 5, 8, 11] + if size is None: + size = [384, 384] + if features is None: + features = [96, 192, 384, 768] + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + 
pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + return pretrained diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py new file mode 100644 index 000000000..71e864cdf --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py @@ -0,0 +1,235 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + +from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, + make_backbone_default, Transpose) + + +def forward_vit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_flex") + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole 
cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + if self.no_embed_class: + x = x + pos_embed + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if not self.no_embed_class: + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def _make_vit_b16_backbone( + model, + features=None, + size=None, + hooks=None, + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + if hooks is None: + hooks = [2, 5, 8, 11] + if size is None: + size = [384, 384] + if features is None: + features = [96, 192, 384, 768] + pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_vit_b_rn50_backbone( + model, + features=None, + size=None, + hooks=None, + vit_features=768, + patch_size=None, + number_stages=2, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + if patch_size is None: + patch_size = [16, 16] + if hooks is None: + hooks = [0, 1, 8, 11] + if size is None: + size = [384, 384] + if features is None: + features = [256, 512, 768, 768] + pretrained = nn.Module() + + pretrained.model = model + + used_number_stages = 0 if use_vit_only else number_stages + for s in range(used_number_stages): + pretrained.model.patch_embed.backbone.stages[s].register_forward_hook( + get_activation(str(s + 1)) + ) + for s in range(used_number_stages, 4): + pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1))) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + for s in range(used_number_stages): + nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + exec(f"pretrained.act_postprocess{s + 1}=value") + for s in range(used_number_stages, 4): + if s < number_stages: + final_layer = nn.ConvTranspose2d( + in_channels=features[s], + out_channels=features[s], + kernel_size=4 // (2 ** s), + stride=4 // (2 ** s), + padding=0, + bias=True, + dilation=1, + groups=1, + ) + elif s > number_stages: + final_layer = nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ) + else: + final_layer = None + + layers = [ + readout_oper[s], + 
Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[s], + kernel_size=1, + stride=1, + padding=0, + ), + ] + if final_layer is not None: + layers.append(final_layer) + + nn.Sequential(*layers) + exec(f"pretrained.act_postprocess{s + 1}=value") + + pretrained.model.start_index = start_index + pretrained.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks is None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py new file mode 100644 index 000000000..5cf430239 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. 
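+ Weights are mapped to CPU when read; checkpoints that also contain
+ optimizer state are unwrapped to their "model" entry before
+ load_state_dict is applied.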
+ + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py new file mode 100644 index 000000000..998a94bda --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py @@ -0,0 +1,441 @@ +import torch +import torch.nn as nn + +from .backbones.beit import ( + _make_pretrained_beitl16_512, + _make_pretrained_beitl16_384, + _make_pretrained_beitb16_384, + forward_beit, +) +from .backbones.swin_common import ( + forward_swin, +) +from .backbones.swin2 import ( + _make_pretrained_swin2l24_384, + _make_pretrained_swin2b24_384, + _make_pretrained_swin2t16_256, +) +from .backbones.swin import ( + _make_pretrained_swinl12_384, +) +from .backbones.levit import ( + _make_pretrained_levit_384, + forward_levit, +) +from .backbones.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=None): + if in_features is None: + in_features = [96, 256, 512, 1024] + if backbone == "beitl16_512": + pretrained = _make_pretrained_beitl16_512( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_512-L (backbone) + elif backbone == "beitl16_384": + pretrained = _make_pretrained_beitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_384-L (backbone) + elif backbone == "beitb16_384": + pretrained = _make_pretrained_beitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # BEiT_384-B (backbone) + elif backbone == "swin2l24_384": + pretrained = _make_pretrained_swin2l24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin2-L/12to24 (backbone) + elif backbone == "swin2b24_384": + pretrained = _make_pretrained_swin2b24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [128, 256, 512, 1024], features, groups=groups, expand=expand + ) # Swin2-B/12to24 (backbone) + elif backbone == "swin2t16_256": + pretrained = _make_pretrained_swin2t16_256( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # Swin2-T/16 (backbone) + elif backbone == "swinl12_384": + pretrained = _make_pretrained_swinl12_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin-L/12 (backbone) + elif backbone == "next_vit_large_6m": + from .backbones.next_vit import _make_pretrained_next_vit_large_6m + pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) + scratch = _make_scratch( + in_features, features, groups=groups, expand=expand + ) # Next-ViT-L on ImageNet-1K-6M (backbone) + elif backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + 
use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + elif backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + raise AssertionError + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", 
"resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand is True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py new file mode 100644 index 000000000..afb997051 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_beit, + forward_swin, + forward_levit, + forward_vit, +) +from .backbones.levit import stem_b4_transpose +from timm.models.layers import get_act_layer + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. 
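+ # Each entry maps a backbone name to the block/stage indices whose activations
+ # feed the reassemble branches (four of them, three for LeViT).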
+ hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + size_refinenet3 = None + self.scratch.stem_transpose = None + + if "beit" in backbone: + self.forward_transformer = forward_beit + elif "swin" in backbone: + self.forward_transformer = forward_swin + elif "next_vit" in backbone: + from .backbones.next_vit import forward_next_vit + self.forward_transformer = forward_next_vit + elif "levit" in backbone: + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + else: + self.forward_transformer = forward_vit + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last is True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + if self.number_layers >= 4: + layer_4_rn = self.scratch.layer4_rn(layer_4) + + if self.number_layers == 3: + path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) + else: + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + if self.scratch.stem_transpose is not None: + path_1 = self.scratch.stem_transpose(path_1) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features + head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 + kwargs.pop("head_features_1", None) + kwargs.pop("head_features_2", None) + + 
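+ # Depth head: halve the channels, upsample 2x, project down to head_features_2,
+ # then predict a single-channel depth map (kept non-negative by the final ReLU
+ # when non_negative is set).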
head = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py new file mode 100644 index 000000000..8a9549778 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py new file mode 100644 index 000000000..cba1bcfff --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py @@ -0,0 +1,130 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks=None): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 + """ + if blocks is None: + blocks = {"expand": True} + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] is True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last is True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py new file mode 100644 index 000000000..98cbc296a --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py @@ -0,0 +1,242 @@ +import cv2 +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +from torchvision.transforms import Compose + +default_models = { + "dpt_beit_large_512": "weights/dpt_beit_large_512.pt", + "dpt_beit_large_384": "weights/dpt_beit_large_384.pt", + "dpt_beit_base_384": "weights/dpt_beit_base_384.pt", + "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt", + "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt", + "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt", + "dpt_swin_large_384": "weights/dpt_swin_large_384.pt", + "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt", + "dpt_levit_224": "weights/dpt_levit_224.pt", + "dpt_large_384": "weights/dpt_large_384.pt", + "dpt_hybrid_384": "weights/dpt_hybrid_384.pt", + "midas_v21_384": "weights/midas_v21_384.pt", + "midas_v21_small_256": "weights/midas_v21_small_256.pt", + "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml", +} + + +def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False): + """Load the specified network. 
+ + Args: + device (device): the torch device used + model_path (str): path to saved model + model_type (str): the type of the model to be loaded + optimize (bool): optimize the model to half-integer on CUDA? + height (int): inference encoder image height + square (bool): resize to a square resolution? + + Returns: + The loaded network, the transform which prepares images as input to the network and the dimensions of the + network input + """ + if "openvino" in model_type: + from openvino.runtime import Core + + keep_aspect_ratio = not square + + if model_type == "dpt_beit_large_512": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_512", + non_negative=True, + ) + net_w, net_h = 512, 512 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_base_384": + model = DPTDepthModel( + path=model_path, + backbone="beitb16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2l24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_base_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2b24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_tiny_256": + model = DPTDepthModel( + path=model_path, + backbone="swin2t16_256", + non_negative=True, + ) + net_w, net_h = 256, 256 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swinl12_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_next_vit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="next_vit_large_6m", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers + # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of + # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py + # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e) + elif model_type == "dpt_levit_224": + model = DPTDepthModel( + path=model_path, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + net_w, net_h = 224, 224 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == 
"dpt_large_384": + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid_384": + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21_384": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small_256": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "openvino_midas_v21_small_256": + ie = Core() + uncompiled_model = ie.read_model(model=model_path) + model = ie.compile_model(uncompiled_model, "CPU") + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + raise AssertionError + + if "openvino" not in model_type: + print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) + else: + print("Model loaded, optimized with OpenVINO") + + if "openvino" in model_type: + keep_aspect_ratio = False + + if height is not None: + net_w, net_h = height, height + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + if "openvino" not in model_type: + model.eval() + + if optimize and (device == torch.device("cuda")): + if "openvino" not in model_type: + model = model.to(memory_format=torch.channels_last) + model = model.half() + else: + print("Error: OpenVINO models are already optimized. No optimization to half-float possible.") + exit() + + if "openvino" not in model_type: + model.to(device) + + return model, transform, net_w, net_h diff --git a/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py new file mode 100644 index 000000000..350cbc116 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. 
+ + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
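+ image_interpolation_method (int, optional):
+ OpenCV interpolation flag used when resizing the image.
+ Defaults to cv2.INTER_AREA.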
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+ """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/modules/control/proc/zoe/zoedepth/models/builder.py b/modules/control/proc/zoe/zoedepth/models/builder.py new file mode 100644 index 000000000..39bad8d39 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/builder.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module +from .depth_model import DepthModel + +def build_model(config) -> DepthModel: + """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. + This function should be used to construct models for training and evaluation. + + Args: + config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. + + Returns: + torch.nn.Module: Model corresponding to name and version as specified in config + """ + module_name = f"zoedepth.models.{config.model}" + try: + module = import_module(module_name) + except ModuleNotFoundError as e: + # print the original error message + print(e) + raise ValueError( + f"Model {config.model} not found. 
Refer above error for details.") from e + try: + get_version = module.get_version + except AttributeError as e: + raise ValueError( + f"Model {config.model} has no get_version function.") from e + return get_version(config.version_name).build_from_config(config) diff --git a/modules/control/proc/zoe/zoedepth/models/depth_model.py b/modules/control/proc/zoe/zoedepth/models/depth_model.py new file mode 100644 index 000000000..37bb610fd --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/depth_model.py @@ -0,0 +1,150 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +import PIL.Image +from PIL import Image +from typing import Union + + +class DepthModel(nn.Module): + def __init__(self, device='cpu'): + super().__init__() + self.device = device + + def to(self, device) -> nn.Module: + self.device = device + return super().to(device) + + def forward(self, x, *args, **kwargs): + raise NotImplementedError + + def _infer(self, x: torch.Tensor): + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + return self(x)['metric_depth'] + + def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: + """ + Inference interface for the model with padding augmentation + Padding augmentation fixes the boundary artifacts in the output depth map. + Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. + This augmentation pads the input image and crops the prediction back to the original size / view. + + Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to pad the input or not. Defaults to True. + fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. + fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. + upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. 
+ padding_mode (str, optional): padding mode. Defaults to "reflect". + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # assert x is nchw and c = 3 + assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) + assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) + + if pad_input: + assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" + pad_h = int(np.sqrt(x.shape[2]/2) * fh) + pad_w = int(np.sqrt(x.shape[3]/2) * fw) + padding = [pad_w, pad_w] + if pad_h > 0: + padding += [pad_h, pad_h] + + x = F.pad(x, padding, mode=padding_mode, **kwargs) + out = self._infer(x) + if out.shape[-2:] != x.shape[-2:]: + out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) + if pad_input: + # crop to the original size, handling the case where pad_h and pad_w is 0 + if pad_h > 0: + out = out[:, :, pad_h:-pad_h,:] + if pad_w > 0: + out = out[:, :, :, pad_w:-pad_w] + return out + + def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model with horizontal flip augmentation + Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # infer with horizontal flip and average + out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) + out = (out + torch.flip(out_flip, dims=[3])) / 2 + return out + + def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + if with_flip_aug: + return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) + else: + return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + + def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: + """ + Inference interface for the model for PIL image + Args: + pil_img (PIL.Image.Image): input PIL image + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". + """ + x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) + out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) + if output_type == "numpy": + return out_tensor.squeeze().cpu().numpy() + elif output_type == "pil": + # uint16 is required for depth pil image + out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) + return Image.fromarray(out_16bit_numpy) + elif output_type == "tensor": + return out_tensor.squeeze().cpu() + else: + raise ValueError(f"output_type {output_type} not supported. 
Supported values are 'numpy', 'pil' and 'tensor'") diff --git a/modules/control/proc/zoe/zoedepth/models/layers/__init__.py b/modules/control/proc/zoe/zoedepth/models/layers/__init__.py new file mode 100644 index 000000000..c344f725c --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/layers/__init__.py @@ -0,0 +1,23 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat diff --git a/modules/control/proc/zoe/zoedepth/models/layers/attractor.py b/modules/control/proc/zoe/zoedepth/models/layers/attractor.py new file mode 100644 index 000000000..c2fe653ed --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/layers/attractor.py @@ -0,0 +1,208 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +@torch.jit.script +def exp_attractor(dx, alpha: float = 300, gamma: int = 2): + """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. 
+ gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx) + + +@torch.jit.script +def inv_attractor(dx, alpha: float = 300, gamma: int = 2): + """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center + This is the default one according to the accompanying paper. + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return dx.div(1+alpha*dx.pow(gamma)) + + +class AttractorLayer(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth) + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm + nn.ReLU(inplace=True) + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + eps = 1e-3 + A = A + eps + n, c, h, w = A.shape + A = A.view(n, self.n_attractors, 2, h, w) + A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w + A_normed = A[:, :, 0, ...] 
# n, na, h, w + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func(dist(A_normed.unsqueeze( + 2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + # .shape N, nbins, h, w + delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = (self.max_depth - self.min_depth) * \ + b_new_centers + self.min_depth + B_centers, _ = torch.sort(B_centers, dim=1) + B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) + return b_new_centers, B_centers + + +class AttractorLayerUnnormed(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are unbounded + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. 
Two outputs just to keep the API consistent with the normed version + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + n, c, h, w = A.shape + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func( + dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + delta_c += dist(A[:, i, ...].unsqueeze(1) - + b_centers) # .shape N, nbins, h, w + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = b_new_centers + + return b_new_centers, B_centers diff --git a/modules/control/proc/zoe/zoedepth/models/layers/dist_layers.py b/modules/control/proc/zoe/zoedepth/models/layers/dist_layers.py new file mode 100644 index 000000000..3208405df --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/layers/dist_layers.py @@ -0,0 +1,121 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +def log_binom(n, k, eps=1e-7): + """ log(nCk) using stirling approximation """ + n = n + eps + k = k + eps + return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) + + +class LogBinomial(nn.Module): + def __init__(self, n_classes=256, act=torch.softmax): + """Compute log binomial distribution for n_classes + + Args: + n_classes (int, optional): number of output classes. Defaults to 256. + """ + super().__init__() + self.K = n_classes + self.act = act + self.register_buffer('k_idx', torch.arange( + 0, n_classes).view(1, -1, 1, 1)) + self.register_buffer('K_minus_1', torch.Tensor( + [self.K-1]).view(1, -1, 1, 1)) + + def forward(self, x, t=1., eps=1e-4): + """Compute log binomial distribution for x + + Args: + x (torch.Tensor - NCHW): probabilities + t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. + eps (float, optional): Small number for numerical stability. Defaults to 1e-4. 
+ + Returns: + torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) + """ + if x.ndim == 3: + x = x.unsqueeze(1) # make it nchw + + one_minus_x = torch.clamp(1 - x, eps, 1) + x = torch.clamp(x, eps, 1) + y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ + torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) + return self.act(y/t, dim=1) + + +class ConditionalLogBinomial(nn.Module): + def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): + """Conditional Log Binomial distribution + + Args: + in_features (int): number of input channels in main feature + condition_dim (int): number of input channels in condition feature + n_classes (int, optional): Number of classes. Defaults to 256. + bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. + p_eps (float, optional): small eps value. Defaults to 1e-4. + max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. + min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. + """ + super().__init__() + self.p_eps = p_eps + self.max_temp = max_temp + self.min_temp = min_temp + self.log_binomial_transform = LogBinomial(n_classes, act=act) + bottleneck = (in_features + condition_dim) // bottleneck_factor + self.mlp = nn.Sequential( + nn.Conv2d(in_features + condition_dim, bottleneck, + kernel_size=1, stride=1, padding=0), + nn.GELU(), + # 2 for p linear norm, 2 for t linear norm + nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), + nn.Softplus() + ) + + def forward(self, x, cond): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Main feature + cond (torch.Tensor - NCHW): condition feature + + Returns: + torch.Tensor: Output log binomial distribution + """ + pt = self.mlp(torch.concat((x, cond), dim=1)) + p, t = pt[:, :2, ...], pt[:, 2:, ...] + + p = p + self.p_eps + p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) + + t = t + self.p_eps + t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...]) + t = t.unsqueeze(1) + t = (self.max_temp - self.min_temp) * t + self.min_temp + + return self.log_binomial_transform(p, t) diff --git a/modules/control/proc/zoe/zoedepth/models/layers/localbins_layers.py b/modules/control/proc/zoe/zoedepth/models/layers/localbins_layers.py new file mode 100644 index 000000000..91d08de0f --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/layers/localbins_layers.py @@ -0,0 +1,169 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class SeedBinRegressor(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Min depth value. Defaults to 1e-3. + max_depth (float, optional): Max depth value. Defaults to 10. + """ + super().__init__() + self.version = "1_1" + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B = self._net(x) + eps = 1e-3 + B = B + eps + B_widths_normed = B / B.sum(dim=1, keepdim=True) + B_widths = (self.max_depth - self.min_depth) * \ + B_widths_normed # .shape NCHW + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad( + B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) + return B_widths_normed, B_centers + + +class SeedBinRegressorUnnormed(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are unbounded + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + """ + super().__init__() + self.version = "1_1" + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B_centers = self._net(x) + return B_centers, B_centers + + +class Projector(nn.Module): + def __init__(self, in_features, out_features, mlp_dim=128): + """Projector MLP + + Args: + in_features (int): input channels + out_features (int): output channels + mlp_dim (int, optional): hidden dimension. Defaults to 128. 
+ """ + super().__init__() + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, out_features, 1, 1, 0), + ) + + def forward(self, x): + return self._net(x) + + + +class LinearSplitter(nn.Module): + def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): + super().__init__() + + self.prev_nbins = prev_nbins + self.split_factor = split_factor + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.GELU(), + nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), + nn.ReLU() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + x : feature block; shape - n, c, h, w + b_prev : previous bin widths normed; shape - n, prev_nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + S = self._net(x) + eps = 1e-3 + S = S + eps + n, c, h, w = S.shape + S = S.view(n, self.prev_nbins, self.split_factor, h, w) + S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits + + b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) + + + b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees + # print(b_prev.shape, S_normed.shape) + # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat? + b = b_prev.unsqueeze(2) * S_normed + b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w + + # calculate bin centers for loss calculation + B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...]) + return b, B_centers diff --git a/modules/control/proc/zoe/zoedepth/models/layers/patch_transformer.py b/modules/control/proc/zoe/zoedepth/models/layers/patch_transformer.py new file mode 100644 index 000000000..23386a068 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/layers/patch_transformer.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): + """ViT-like transformer block + + Args: + in_channels (int): Input channels + patch_size (int, optional): patch size. Defaults to 10. + embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. + num_heads (int, optional): number of attention heads. Defaults to 4. + use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. + """ + super(PatchTransformerEncoder, self).__init__() + self.use_class_token = use_class_token + encoder_layers = nn.TransformerEncoderLayer( + embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): + """Generate positional encodings + + Args: + sequence_length (int): Sequence length + embedding_dim (int): Embedding dimension + + Returns: + torch.Tensor SBE: Positional encodings + """ + position = torch.arange( + 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) + index = torch.arange( + 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) + div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) + pos_encoding = position * div_term + pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) + pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) + return pos_encoding + + + def forward(self, x): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Input feature tensor + + Returns: + torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim + """ + embeddings = self.embedding_convPxP(x).flatten( + 2) # .shape = n,c,s = n, embedding_dim, s + if self.use_class_token: + # extra special token at start ? 
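+ # pad=(1, 0) on the last (sequence) dimension prepends a single all-zero embedding,
+ # i.e. a zero "class token" slot that the transformer can use to accumulate global information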
+ embeddings = nn.functional.pad(embeddings, (1, 0)) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + S, N, E = embeddings.shape + embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x diff --git a/modules/control/proc/zoe/zoedepth/models/model_io.py b/modules/control/proc/zoe/zoedepth/models/model_io.py new file mode 100644 index 000000000..c42f51641 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/model_io.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch + +def load_state_dict(model, state_dict): + """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. + + DataParallel prefixes state_dict keys with 'module.' when saving. + If the model is not a DataParallel model but the state_dict is, then prefixes are removed. + If the model is a DataParallel model but the state_dict is not, then prefixes are added. + """ + state_dict = state_dict.get('model', state_dict) + # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' + + do_prefix = isinstance( + model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) + state = {} + for k, v in state_dict.items(): + if k.startswith('module.') and not do_prefix: + k = k[7:] + + if not k.startswith('module.') and do_prefix: + k = 'module.' + k + + state[k] = v + + model.load_state_dict(state) + print("Loaded successfully") + return model + + +def load_wts(model, checkpoint_path): + ckpt = torch.load(checkpoint_path, map_location='cpu') + return load_state_dict(model, ckpt) + + +def load_state_dict_from_url(model, url, **kwargs): + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) + return load_state_dict(model, state_dict) + + +def load_state_from_resource(model, resource: str): + """Loads weights to the model from a given resource. A resource can be of following types: + 1. URL. Prefixed with "url::" + e.g. url::http(s)://url.resource.com/ckpt.pt + + 2. Local path. Prefixed with "local::" + e.g. 
local::/path/to/ckpt.pt + + + Args: + model (torch.nn.Module): Model + resource (str): resource string + + Returns: + torch.nn.Module: Model with loaded weights + """ + print(f"Using pretrained resource {resource}") + + if resource.startswith('url::'): + url = resource.split('url::')[1] + return load_state_dict_from_url(model, url, progress=True) + + elif resource.startswith('local::'): + path = resource.split('local::')[1] + return load_wts(model, path) + + else: + raise ValueError("Invalid resource type, only url:: and local:: are supported") diff --git a/modules/control/proc/zoe/zoedepth/models/zoedepth/__init__.py b/modules/control/proc/zoe/zoedepth/models/zoedepth/__init__.py new file mode 100644 index 000000000..8532e9b9b --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/zoedepth/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
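For reference, the `url::` / `local::` resource convention implemented by `load_state_from_resource` above can be sketched in isolation as follows; the checkpoint URL is only an example of the format used in the bundled configs, and the helper name `resolve_checkpoint` is not part of the patch.

```python
import torch

def resolve_checkpoint(resource: str) -> dict:
    """Minimal sketch of the resource-string convention used above."""
    if resource.startswith("url::"):
        url = resource.split("url::", 1)[1]
        return torch.hub.load_state_dict_from_url(url, map_location="cpu", progress=True)
    if resource.startswith("local::"):
        path = resource.split("local::", 1)[1]
        return torch.load(path, map_location="cpu")
    raise ValueError("Invalid resource type, only url:: and local:: are supported")

# e.g. resolve_checkpoint("url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt")
```

Note that `load_state_dict` additionally reconciles the `module.` prefix that `DataParallel` adds to parameter names, so checkpoints saved from wrapped and unwrapped models both load.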
+ +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] diff --git a/modules/control/proc/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/modules/control/proc/zoe/zoedepth/models/zoedepth/config_zoedepth.json new file mode 100644 index 000000000..3112ed78c --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/zoedepth/config_zoedepth.json @@ -0,0 +1,58 @@ +{ + "model": { + "name": "ZoeDepth", + "version_name": "v1", + "n_bins": 64, + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "midas_model_type" : "DPT_BEiT_L_384", + "min_temp": 0.0212, + "max_temp": 50.0, + "output_distribution": "logbinomial", + "memory_efficient": true, + "inverse_midas": false, + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 0.2, + "w_reg": 0, + "w_grad": 0, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "midas_lr_factor": 1, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10, + "freeze_midas_bn": true + + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null, + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null + } +} \ No newline at end of file diff --git a/modules/control/proc/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/modules/control/proc/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json new file mode 100644 index 000000000..b51802aa4 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json @@ -0,0 +1,22 @@ +{ + "model": { + "bin_centers_type": "normed", + "img_size": [384, 768] + }, + + "train": { + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" + } +} \ No newline at end of file diff --git a/modules/control/proc/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/modules/control/proc/zoe/zoedepth/models/zoedepth/zoedepth_v1.py new file mode 100644 index 000000000..1705442c6 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/zoedepth/zoedepth_v1.py @@ -0,0 +1,252 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in 
all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=None, attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. 
+ """ + if n_attractors is None: + n_attractors = [16, 8, 4, 1] + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). 
Present only if return_probs is True + + """ + b, c, h, w = x.shape + # print("input shape ", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. 
+ """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) diff --git a/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/__init__.py b/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/__init__.py new file mode 100644 index 000000000..61cd507ca --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +from .zoedepth_nk_v1 import ZoeDepthNK + +all_versions = { + "v1": ZoeDepthNK, +} + +get_version = lambda v : all_versions[v] diff --git a/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json new file mode 100644 index 000000000..42bab2a3a --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json @@ -0,0 +1,67 @@ +{ + "model": { + "name": "ZoeDepthNK", + "version_name": "v1", + "bin_conf" : [ + { + "name": "nyu", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 10.0 + }, + { + "name": "kitti", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 80.0 + } + ], + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "min_temp": 0.0212, + "max_temp": 50.0, + "memory_efficient": true, + "midas_model_type" : "DPT_BEiT_L_384", + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth_nk", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 100, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "w_grad": 0, + "w_reg": 0, + "midas_lr_factor": 10, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10 + }, + + "infer": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false, + "force_keep_ar": true + }, + + "eval": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false + } +} \ No newline at end of file diff --git a/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py new file mode 100644 index 000000000..889b1e282 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py @@ -0,0 +1,333 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
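The `version_name` field in the JSON configs above is resolved through the small `all_versions` registry exposed by each model package. A hedged sketch of how the two pieces fit together is shown below; the config path and import path are illustrative, and the project's actual builder may differ in details.

```python
import json

def build_zoe_model(config_path: str, registry: dict):
    """Sketch: look up the model class by "version_name" and build it from the config."""
    with open(config_path, "r", encoding="utf-8") as f:
        model_conf = json.load(f)["model"]
    model_cls = registry[model_conf.pop("version_name")]   # e.g. "v1" -> ZoeDepthNK
    return model_cls.build_from_config(model_conf)

# usage (illustrative paths):
# from modules.control.proc.zoe.zoedepth.models.zoedepth_nk import all_versions
# model = build_zoe_model(".../config_zoedepth_nk.json", all_versions)
```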
+ +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn + +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..layers.patch_transformer import PatchTransformerEncoder +from ..model_io import load_state_from_resource + +class ZoeDepthNK(DepthModel): + def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128, + n_attractors=None, attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', + min_temp=5, max_temp=50, + memory_efficient=False, train_midas=True, + is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts. + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + + bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys: + "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float) + + The length of this list determines the number of metric heads. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + + memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False. + + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. 
+ + """ + + if n_attractors is None: + n_attractors = [16, 8, 4, 1] + super().__init__() + + self.core = core + self.bin_conf = bin_conf + self.min_temp = min_temp + self.max_temp = max_temp + self.memory_efficient = memory_efficient + self.train_midas = train_midas + self.is_midas_pretrained = is_midas_pretrained + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.inverse_midas = inverse_midas + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + # self.scales = [16, 8, 4, 2] # spatial scale factors + + self.conv2 = nn.Conv2d( + btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0) + + # Transformer classifier on the bottleneck + self.patch_transformer = PatchTransformerEncoder( + btlnck_features, 1, 128, use_class_token=True) + self.mlp_classifier = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 2) + ) + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + self.bin_centers_type = bin_centers_type + # We have bins for each bin conf. + # Create a map (ModuleDict) of 'name' -> seed_bin_regressor + self.seed_bin_regressors = nn.ModuleDict( + {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for conf in bin_conf} + ) + + self.seed_projector = Projector( + btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + for num_out in num_out_features + ]) + + # Create a map (ModuleDict) of 'name' -> attractors (ModuleList) + self.attractors = nn.ModuleDict( + {conf['name']: nn.ModuleList([ + Attractor(bin_embedding_dim, n_attractors[i], + mlp_dim=bin_embedding_dim, alpha=attractor_alpha, + gamma=attractor_gamma, kind=attractor_kind, + attractor_type=attractor_type, memory_efficient=memory_efficient, + min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for i in range(len(n_attractors)) + ]) + for conf in bin_conf} + ) + + last_in = N_MIDAS_OUT + # conditional log binomial for each bin conf + self.conditional_log_binomial = nn.ModuleDict( + {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp) + for conf in bin_conf} + ) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain. + return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. Defaults to False. + return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False. 
+ + Returns: + dict: Dictionary of outputs with keys: + - "rel_depth": Relative depth map of shape (B, 1, H, W) + - "metric_depth": Metric depth map of shape (B, 1, H, W) + - "domain_logits": Domain logits of shape (B, 2) + - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True + - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True + """ + b, c, h, w = x.shape + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + + # Predict which path to take + embedding = self.patch_transformer(x)[0] # N, E + domain_logits = self.mlp_classifier(embedding) # N, 2 + domain_vote = torch.softmax(domain_logits.sum( + dim=0, keepdim=True), dim=-1) # 1, 2 + + # Get the path + bin_conf_name = ["nyu", "kitti"][torch.argmax( + domain_vote, dim=-1).squeeze().item()] + + try: + conf = [c for c in self.bin_conf if c.name == bin_conf_name][0] + except IndexError as e: + raise ValueError(f"bin_conf_name {bin_conf_name} not found in bin_confs") from e + + min_depth = conf['min_depth'] + max_depth = conf['max_depth'] + + seed_bin_regressor = self.seed_bin_regressors[bin_conf_name] + _, seed_b_centers = seed_bin_regressor(x) + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth) + else: + b_prev = seed_b_centers + prev_b_embedding = self.seed_projector(x) + + attractors = self.attractors[bin_conf_name] + for projector, attractor, x in zip(self.projectors, attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b + prev_b_embedding = b_embedding + + last = outconv_activation + + b_centers = nn.functional.interpolate( + b_centers, last.shape[-2:], mode='bilinear', align_corners=True) + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + + clb = self.conditional_log_binomial[bin_conf_name] + x = clb(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + output = dict(domain_logits=domain_logits, metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. 
+ """ + param_conf = [] + if self.train_midas: + def get_rel_pos_params(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + encoder_params = get_enc_params_except_rel_pos() + rel_pos_params = get_rel_pos_params() + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0 + param_conf.extend([ + {'params': encoder_params, 'lr': lr / self.encoder_lr_factor}, + {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor}, + {'params': midas_params, 'lr': lr / midas_lr_factor} + ]) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + param_conf.append({'params': remaining_params, 'lr': lr}) + return param_conf + + def get_conf_parameters(self, conf_name): + """ + Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + params = [] + for _name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + if bin_conf_name == conf_name: + params += list(module.parameters()) + return params + + def freeze_conf(self, conf_name): + """ + Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = False + + def unfreeze_conf(self, conf_name): + """ + Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = True + + def freeze_all_confs(self): + """ + Freezes all the parameters of all the ModuleDicts children + """ + for _name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for _bin_conf_name, module in child.items(): + for p in module.parameters(): + p.requires_grad = False + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepthNK(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepthNK.build(**config) diff --git a/modules/control/proc/zoe/zoedepth/utils/__init__.py b/modules/control/proc/zoe/zoedepth/utils/__init__.py new file mode 100644 index 000000000..5f2668792 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/modules/control/proc/zoe/zoedepth/utils/arg_utils.py b/modules/control/proc/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 000000000..8a3004ec3 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/modules/control/proc/zoe/zoedepth/utils/config.py b/modules/control/proc/zoe/zoedepth/utils/config.py new file mode 100644 index 000000000..24525d947 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + 
"diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. + Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. 
+ + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all(isinstance(e, dtype) for e in config[key] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. + + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. 
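+            For example, n_bins=64, or in the combined form __n_bins__lr='64;0.0001'
+            which is handled by split_combined_args.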
+ + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/modules/control/proc/zoe/zoedepth/utils/easydict/__init__.py b/modules/control/proc/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 000000000..fe47f0173 --- /dev/null +++ b/modules/control/proc/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... 
+ AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/modules/control/processors.py b/modules/control/processors.py new file mode 100644 index 000000000..a52356da2 --- /dev/null +++ b/modules/control/processors.py @@ -0,0 +1,222 @@ +import os +import time +import torch +from PIL import Image +from modules.shared import log +from modules.errors import display + +from modules.control.proc.hed import HEDdetector +from modules.control.proc.canny import CannyDetector +from modules.control.proc.edge import EdgeDetector +from modules.control.proc.lineart import LineartDetector +from modules.control.proc.lineart_anime import LineartAnimeDetector +from modules.control.proc.pidi import PidiNetDetector +from modules.control.proc.mediapipe_face import MediapipeFaceDetector +from modules.control.proc.shuffle import ContentShuffleDetector + +from modules.control.proc.leres import 
LeresDetector +from modules.control.proc.midas import MidasDetector +from modules.control.proc.mlsd import MLSDdetector +from modules.control.proc.normalbae import NormalBaeDetector +from modules.control.proc.openpose import OpenposeDetector +from modules.control.proc.dwpose import DWposeDetector +from modules.control.proc.segment_anything import SamDetector +from modules.control.proc.zoe import ZoeDetector + + +models = {} +cache_dir = 'models/control/processors' +debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: CONTROL') +config = { + # pose models + 'OpenPose': {'class': OpenposeDetector, 'checkpoint': True, 'params': {'include_body': True, 'include_hand': False, 'include_face': False}}, + 'DWPose': {'class': DWposeDetector, 'checkpoint': False, 'model': 'Tiny', 'params': {'min_confidence': 0.3}}, + 'MediaPipe Face': {'class': MediapipeFaceDetector, 'checkpoint': False, 'params': {'max_faces': 1, 'min_confidence': 0.5}}, + # outline models + 'Canny': {'class': CannyDetector, 'checkpoint': False, 'params': {'low_threshold': 100, 'high_threshold': 200}}, + 'Edge': {'class': EdgeDetector, 'checkpoint': False, 'params': {'pf': True, 'mode': 'edge'}}, + 'LineArt Realistic': {'class': LineartDetector, 'checkpoint': True, 'params': {'coarse': False}}, + 'LineArt Anime': {'class': LineartAnimeDetector, 'checkpoint': True, 'params': {}}, + 'HED': {'class': HEDdetector, 'checkpoint': True, 'params': {'scribble': False, 'safe': False}}, + 'PidiNet': {'class': PidiNetDetector, 'checkpoint': True, 'params': {'scribble': False, 'safe': False, 'apply_filter': False}}, + # depth models + 'Midas Depth Hybrid': {'class': MidasDetector, 'checkpoint': True, 'params': {'bg_th': 0.1, 'depth_and_normal': False}}, + 'Leres Depth': {'class': LeresDetector, 'checkpoint': True, 'params': {'boost': False, 'thr_a':0, 'thr_b':0}}, + 'Zoe Depth': {'class': ZoeDetector, 'checkpoint': True, 'params': {'gamma_corrected': False}, 'load_config': {'pretrained_model_or_path': 'halffried/gyre_zoedepth', 'filename': 'ZoeD_M12_N.safetensors', 'model_type': "zoedepth"}}, + 'Normal Bae': {'class': NormalBaeDetector, 'checkpoint': True, 'params': {}}, + # segmentation models + 'SegmentAnything': {'class': SamDetector, 'checkpoint': True, 'model': 'Base', 'params': {}}, + # other models + 'MLSD': {'class': MLSDdetector, 'checkpoint': True, 'params': {'thr_v': 0.1, 'thr_d': 0.1}}, + 'Shuffle': {'class': ContentShuffleDetector, 'checkpoint': False, 'params': {}}, + # 'Midas Depth Large': {'class': MidasDetector, 'checkpoint': True, 'params': {'bg_th': 0.1, 'depth_and_normal': False}, 'load_config': {'pretrained_model_or_path': 'Intel/dpt-large', 'model_type': "dpt_large", 'filename': ''}}, + # 'Zoe Depth Zoe': {'class': ZoeDetector, 'checkpoint': True, 'params': {}}, + # 'Zoe Depth NK': {'class': ZoeDetector, 'checkpoint': True, 'params': {}, 'load_config': {'pretrained_model_or_path': 'halffried/gyre_zoedepth', 'filename': 'ZoeD_M12_NK.safetensors', 'model_type': "zoedepth_nk"}}, +} + + +def list_models(refresh=False): + global models # pylint: disable=global-statement + if not refresh and len(models) > 0: + return models + models = ['None'] + list(config) + debug(f'Control list processors: path={cache_dir} models={models}') + return models + + +def update_settings(*settings): + debug(f'Control settings: {settings}') + def update(what, val): + processor_id = what[0] + if len(what) == 2 and config[processor_id][what[1]] != val: + config[processor_id][what[1]] 
= val + config[processor_id]['dirty'] = True + log.debug(f'Control settings: id="{processor_id}" {what[-1]}={val}') + elif len(what) == 3 and config[processor_id][what[1]][what[2]] != val: + config[processor_id][what[1]][what[2]] = val + config[processor_id]['dirty'] = True + log.debug(f'Control settings: id="{processor_id}" {what[-1]}={val}') + elif len(what) == 4 and config[processor_id][what[1]][what[2]][what[3]] != val: + config[processor_id][what[1]][what[2]][what[3]] = val + config[processor_id]['dirty'] = True + log.debug(f'Control settings: id="{processor_id}" {what[-1]}={val}') + + update(['HED', 'params', 'scribble'], settings[0]) + update(['Midas Depth Hybrid', 'params', 'bg_th'], settings[1]) + update(['Midas Depth Hybrid', 'params', 'depth_and_normal'], settings[2]) + update(['MLSD', 'params', 'thr_v'], settings[3]) + update(['MLSD', 'params', 'thr_d'], settings[4]) + update(['OpenPose', 'params', 'include_body'], settings[5]) + update(['OpenPose', 'params', 'include_hand'], settings[6]) + update(['OpenPose', 'params', 'include_face'], settings[7]) + update(['PidiNet', 'params', 'scribble'], settings[8]) + update(['PidiNet', 'params', 'apply_filter'], settings[9]) + update(['LineArt Realistic', 'params', 'coarse'], settings[10]) + update(['Leres Depth', 'params', 'boost'], settings[11]) + update(['Leres Depth', 'params', 'thr_a'], settings[12]) + update(['Leres Depth', 'params', 'thr_b'], settings[13]) + update(['MediaPipe Face', 'params', 'max_faces'], settings[14]) + update(['MediaPipe Face', 'params', 'min_confidence'], settings[15]) + update(['Canny', 'params', 'low_threshold'], settings[16]) + update(['Canny', 'params', 'high_threshold'], settings[17]) + update(['DWPose', 'model'], settings[18]) + update(['DWPose', 'params', 'min_confidence'], settings[19]) + update(['SegmentAnything', 'model'], settings[20]) + update(['Edge', 'params', 'pf'], settings[21]) + update(['Edge', 'params', 'mode'], settings[22]) + update(['Zoe Depth', 'params', 'gamma_corrected'], settings[23]) + + +class Processor(): + def __init__(self, processor_id: str = None, resize = True, load_config = None): + self.model = None + self.resize = resize + self.processor_id = processor_id + self.override = None # override input image + self.load_config = { 'cache_dir': cache_dir } + from_config = config.get(processor_id, {}).get('load_config', None) + if load_config is not None: + for k, v in load_config.items(): + self.load_config[k] = v + if from_config is not None: + for k, v in from_config.items(): + self.load_config[k] = v + if processor_id is not None: + self.load() + + def reset(self): + if self.model is not None: + log.debug(f'Control processor unloaded: id="{self.processor_id}"') + self.model = None + self.processor_id = None + self.override = None + + def load(self, processor_id: str = None) -> str: + try: + t0 = time.time() + processor_id = processor_id or self.processor_id + if processor_id is None or processor_id == 'None': + self.reset() + return '' + from_config = config.get(processor_id, {}).get('load_config', None) + if from_config is not None: + for k, v in from_config.items(): + self.load_config[k] = v + cls = config[processor_id]['class'] + log.debug(f'Control processor loading: id="{processor_id}" class={cls.__name__}') + debug(f'Control processor config={self.load_config}') + if 'DWPose' in processor_id: + det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' + if 'Tiny' == config['DWPose']['model']: 
+ pose_config = 'config/rtmpose-t_8xb64-270e_coco-ubody-wholebody-256x192.py' + pose_ckpt = 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-tt_ucoco.pth' + elif 'Medium' == config['DWPose']['model']: + pose_config = 'config/rtmpose-m_8xb64-270e_coco-ubody-wholebody-256x192.py' + pose_ckpt = 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-mm_ucoco.pth' + elif 'Large' == config['DWPose']['model']: + pose_config = 'config/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' + pose_ckpt = 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.pth' + else: + log.error(f'Control processor load failed: id="{processor_id}" error=unknown model type') + return f'Processor failed to load: {processor_id}' + self.model = cls(det_ckpt=det_ckpt, pose_config=pose_config, pose_ckpt=pose_ckpt, device="cpu") + elif 'SegmentAnything' in processor_id: + if 'Base' == config['SegmentAnything']['model']: + self.model = cls.from_pretrained(model_path = 'segments-arnaud/sam_vit_b', filename='sam_vit_b_01ec64.pth', model_type='vit_b', **self.load_config) + elif 'Large' == config['SegmentAnything']['model']: + self.model = cls.from_pretrained(model_path = 'segments-arnaud/sam_vit_l', filename='sam_vit_l_0b3195.pth', model_type='vit_l', **self.load_config) + else: + log.error(f'Control processor load failed: id="{processor_id}" error=unknown model type') + return f'Processor failed to load: {processor_id}' + elif config[processor_id].get('load_config', None) is not None: + self.model = cls.from_pretrained(**self.load_config) + elif config[processor_id]['checkpoint']: + self.model = cls.from_pretrained("lllyasviel/Annotators", **self.load_config) + else: + self.model = cls() # class instance only + t1 = time.time() + self.processor_id = processor_id + log.debug(f'Control processor loaded: id="{processor_id}" class={self.model.__class__.__name__} time={t1-t0:.2f}') + return f'Processor loaded: {processor_id}' + except Exception as e: + log.error(f'Control processor load failed: id="{processor_id}" error={e}') + display(e, 'Control processor load') + return f'Processor load filed: {processor_id}' + + def __call__(self, image_input: Image): + if self.override is not None: + image_input = self.override + image_process = image_input + if image_input is None: + log.error('Control processor: no input') + return image_process + if self.model is None: + # log.error('Control processor: model not loaded') + return image_process + if config[self.processor_id].get('dirty', False): + processor_id = self.processor_id + config[processor_id].pop('dirty') + self.reset() + self.load(processor_id) + try: + t0 = time.time() + kwargs = config.get(self.processor_id, {}).get('params', None) + if self.resize: + orig_size = image_input.size + image_resized = image_input.resize((512, 512)) + else: + image_resized = image_input + with torch.no_grad(): + image_process = self.model(image_resized, **kwargs) + if self.resize: + image_process = image_process.resize(orig_size, Image.Resampling.LANCZOS) + t1 = time.time() + log.debug(f'Control processor: id="{self.processor_id}" args={kwargs} time={t1-t0:.2f}') + except Exception as e: + log.error(f'Control processor failed: id="{self.processor_id}" error={e}') + display(e, 'Control processor') + return image_process + + def preview(self, image_input: Image): + return self.__call__(image_input) diff --git a/modules/control/run.py b/modules/control/run.py new file mode 100644 index 000000000..e4dc60dfe --- /dev/null +++ b/modules/control/run.py @@ -0,0 +1,507 @@ +import os +import 
time +import math +from typing import List, Union +import cv2 +import numpy as np +import diffusers +from PIL import Image +from modules.control import util +from modules.control import unit +from modules.control import processors +from modules.control.units import controlnet # lllyasviel ControlNet +from modules.control.units import xs # VisLearn ControlNet-XS +from modules.control.units import lite # Kohya ControlLLLite +from modules.control.units import t2iadapter # TencentARC T2I-Adapter +from modules.control.units import reference # ControlNet-Reference +from modules.control.units import ipadapter # IP-Adapter +from modules import devices, shared, errors, processing, images, sd_models, sd_samplers + + +debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: CONTROL') +pipe = None +original_pipeline = None + + +class ControlProcessing(processing.StableDiffusionProcessingImg2Img): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.strength = None + self.adapter_conditioning_scale = None + self.adapter_conditioning_factor = None + self.guess_mode = None + self.controlnet_conditioning_scale = None + self.control_guidance_start = None + self.control_guidance_end = None + self.reference_attn = None + self.reference_adain = None + self.attention_auto_machine_weight = None + self.gn_auto_machine_weight = None + self.style_fidelity = None + self.ref_image = None + self.image = None + self.query_weight = None + self.adain_weight = None + self.adapter_conditioning_factor = 1.0 + self.attention = 'Attention' + self.fidelity = 0.5 + self.override = None + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): # abstract + pass + + +def restore_pipeline(): + global pipe # pylint: disable=global-statement + pipe = None + if original_pipeline is not None: + shared.sd_model = original_pipeline + debug(f'Control restored pipeline: class={shared.sd_model.__class__.__name__}') + devices.torch_gc() + + +def control_run(units: List[unit.Unit], inputs, inits, unit_type: str, is_generator: bool, input_type: int, + prompt, negative, styles, steps, sampler_index, + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, + cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, full_quality, restore_faces, tiling, + hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, + resize_mode, resize_name, width, height, scale_by, selected_scale_tab, resize_time, + denoising_strength, batch_count, batch_size, + video_skip_frames, video_type, video_duration, video_loop, video_pad, video_interpolate, + ip_adapter, ip_scale, ip_image, ip_type, + ): + global pipe, original_pipeline # pylint: disable=global-statement + debug(f'Control {unit_type}: input={inputs} init={inits} type={input_type}') + if inputs is None or (type(inputs) is list and len(inputs) == 0): + inputs = [None] + output_images: List[Image.Image] = [] # output images + active_process: List[processors.Processor] = [] # all active preprocessors + active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models + active_strength: List[float] = [] # strength factors for all active models + active_start: List[float] = [] # start step for all active models + active_end: List[float] = [] # end step for all active models + processed_image: Image.Image = None # last processed 
image + width = 8 * math.ceil(width / 8) + height = 8 * math.ceil(height / 8) + + p = ControlProcessing( + prompt = prompt, + negative_prompt = negative, + styles = styles, + steps = steps, + sampler_name = sd_samplers.samplers[sampler_index].name, + latent_sampler = sd_samplers.samplers[sampler_index].name, + seed = seed, + subseed = subseed, + subseed_strength = subseed_strength, + seed_resize_from_h = seed_resize_from_h, + seed_resize_from_w = seed_resize_from_w, + cfg_scale = cfg_scale, + clip_skip = clip_skip, + image_cfg_scale = image_cfg_scale, + diffusers_guidance_rescale = diffusers_guidance_rescale, + full_quality = full_quality, + restore_faces = restore_faces, + tiling = tiling, + hdr_clamp = hdr_clamp, + hdr_boundary = hdr_boundary, + hdr_threshold = hdr_threshold, + hdr_center = hdr_center, + hdr_channel_shift = hdr_channel_shift, + hdr_full_shift = hdr_full_shift, + hdr_maximize = hdr_maximize, + hdr_max_center = hdr_max_center, + hdr_max_boundry = hdr_max_boundry, + resize_mode = resize_mode if resize_name != 'None' else 0, + resize_name = resize_name, + scale_by = scale_by, + selected_scale_tab = selected_scale_tab, + denoising_strength = denoising_strength, + n_iter = batch_count, + batch_size = batch_size, + ) + processing.process_init(p) + + if resize_mode != 0 or inputs is None or inputs == [None]: + p.width = width # pylint: disable=attribute-defined-outside-init + p.height = height # pylint: disable=attribute-defined-outside-init + if selected_scale_tab == 1: + width = int(width * scale_by) + height = int(height * scale_by) + else: + del p.width + del p.height + + t0 = time.time() + for u in units: + if not u.enabled or u.type != unit_type: + continue + if unit_type == 'adapter' and u.adapter.model is not None: + active_process.append(u.process) + active_model.append(u.adapter) + active_strength.append(float(u.strength)) + p.adapter_conditioning_factor = u.factor + shared.log.debug(f'Control T2I-Adapter unit: process={u.process.processor_id} model={u.adapter.model_id} strength={u.strength} factor={u.factor}') + elif unit_type == 'controlnet' and u.controlnet.model is not None: + active_process.append(u.process) + active_model.append(u.controlnet) + active_strength.append(float(u.strength)) + active_start.append(float(u.start)) + active_end.append(float(u.end)) + p.guess_mode = u.guess + shared.log.debug(f'Control ControlNet unit: process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}') + elif unit_type == 'xs' and u.controlnet.model is not None: + active_process.append(u.process) + active_model.append(u.controlnet) + active_strength.append(float(u.strength)) + active_start.append(float(u.start)) + active_end.append(float(u.end)) + shared.log.debug(f'Control ControlNet-XS unit: process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}') + elif unit_type == 'lite' and u.controlnet.model is not None: + active_process.append(u.process) + active_model.append(u.controlnet) + active_strength.append(float(u.strength)) + shared.log.debug(f'Control ControlNet-XS unit: process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}') + elif unit_type == 'reference': + p.override = u.override + p.attention = u.attention + p.query_weight = float(u.query_weight) + p.adain_weight = float(u.adain_weight) + p.fidelity = u.fidelity + shared.log.debug('Control Reference unit') + 
else: + active_process.append(u.process) + # active_model.append(model) + active_strength.append(float(u.strength)) + p.ops.append('control') + + has_models = False + selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None + if unit_type == 'adapter' or unit_type == 'controlnet' or unit_type == 'xs' or unit_type == 'lite': + if len(active_model) == 0: + selected_models = None + elif len(active_model) == 1: + selected_models = active_model[0].model if active_model[0].model is not None else None + p.extra_generation_params["Control model"] = (active_model[0].model_id or '') if active_model[0].model is not None else None + has_models = selected_models is not None + else: + selected_models = [m.model for m in active_model if m.model is not None] + p.extra_generation_params["Control model"] = ', '.join([(m.model_id or '') for m in active_model if m.model is not None]) + has_models = len(selected_models) > 0 + use_conditioning = active_strength[0] if len(active_strength) == 1 else list(active_strength) # strength or list[strength] + else: + pass + + debug(f'Control: run type={unit_type} models={has_models}') + if unit_type == 'adapter' and has_models: + p.extra_generation_params["Control mode"] = 'Adapter' + p.extra_generation_params["Control conditioning"] = use_conditioning + p.task_args['adapter_conditioning_scale'] = use_conditioning + instance = t2iadapter.AdapterPipeline(selected_models, shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: T2I-Adapter does not support separate init image') + elif unit_type == 'controlnet' and has_models: + p.extra_generation_params["Control mode"] = 'ControlNet' + p.extra_generation_params["Control conditioning"] = use_conditioning + p.task_args['controlnet_conditioning_scale'] = use_conditioning + p.task_args['control_guidance_start'] = active_start[0] if len(active_start) == 1 else list(active_start) + p.task_args['control_guidance_end'] = active_end[0] if len(active_end) == 1 else list(active_end) + p.task_args['guess_mode'] = p.guess_mode + instance = controlnet.ControlNetPipeline(selected_models, shared.sd_model) + pipe = instance.pipeline + elif unit_type == 'xs' and has_models: + p.extra_generation_params["Control mode"] = 'ControlNet-XS' + p.extra_generation_params["Control conditioning"] = use_conditioning + p.controlnet_conditioning_scale = use_conditioning + p.control_guidance_start = active_start[0] if len(active_start) == 1 else list(active_start) + p.control_guidance_end = active_end[0] if len(active_end) == 1 else list(active_end) + instance = xs.ControlNetXSPipeline(selected_models, shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: ControlNet-XS does not support separate init image') + elif unit_type == 'lite' and has_models: + p.extra_generation_params["Control mode"] = 'ControlLLLite' + p.extra_generation_params["Control conditioning"] = use_conditioning + p.controlnet_conditioning_scale = use_conditioning + instance = lite.ControlLLitePipeline(shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: ControlLLLite does not support separate init image') + elif unit_type == 'reference': + p.extra_generation_params["Control mode"] = 'Reference' + p.extra_generation_params["Control attention"] = p.attention + p.task_args['reference_attn'] = 'Attention' in p.attention + p.task_args['reference_adain'] = 'Adain' in p.attention + 
p.task_args['attention_auto_machine_weight'] = p.query_weight + p.task_args['gn_auto_machine_weight'] = p.adain_weight + p.task_args['style_fidelity'] = p.fidelity + instance = reference.ReferencePipeline(shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: ControlNet-XS does not support separate init image') + else: # run in img2img mode + if len(active_strength) > 0: + p.strength = active_strength[0] + pipe = diffusers.AutoPipelineForImage2Image.from_pipe(shared.sd_model) # use set_diffuser_pipe + instance = None + + debug(f'Control pipeline: class={pipe.__class__} args={vars(p)}') + t1, t2, t3 = time.time(), 0, 0 + status = True + frame = None + video = None + output_filename = None + index = 0 + frames = 0 + + original_pipeline = shared.sd_model + if pipe is not None: + shared.sd_model = pipe + if not ((shared.opts.diffusers_model_cpu_offload or shared.cmd_opts.medvram) or (shared.opts.diffusers_seq_cpu_offload or shared.cmd_opts.lowvram)): + shared.sd_model.to(shared.device) + sd_models.copy_diffuser_options(shared.sd_model, original_pipeline) # copy options from original pipeline + sd_models.set_diffuser_options(shared.sd_model) + if ipadapter.apply_ip_adapter(shared.sd_model, p, ip_adapter, ip_scale, ip_image, reset=True): + original_pipeline.feature_extractor = shared.sd_model.feature_extractor + original_pipeline.image_encoder = shared.sd_model.image_encoder + + try: + with devices.inference_context(): + if isinstance(inputs, str): # only video, the rest is a list + if input_type == 2: # separate init image + if isinstance(inits, str) and inits != inputs: + shared.log.warning('Control: separate init video not support for video input') + input_type = 1 + try: + video = cv2.VideoCapture(inputs) + if not video.isOpened(): + msg = f'Control: video open failed: path={inputs}' + shared.log.error(msg) + restore_pipeline() + return msg + frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = int(video.get(cv2.CAP_PROP_FPS)) + w, h = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + codec = util.decode_fourcc(video.get(cv2.CAP_PROP_FOURCC)) + status, frame = video.read() + shared.log.debug(f'Control: input video: path={inputs} frames={frames} fps={fps} size={w}x{h} codec={codec}') + except Exception as e: + msg = f'Control: video open failed: path={inputs} {e}' + shared.log.error(msg) + restore_pipeline() + return msg + + while status: + processed_image = None + if frame is not None: + inputs = [Image.fromarray(frame)] # cv2 to pil + for i, input_image in enumerate(inputs): + debug(f'Control Control image: {i + 1} of {len(inputs)}') + if shared.state.skipped: + shared.state.skipped = False + continue + if shared.state.interrupted: + shared.state.interrupted = False + restore_pipeline() + return 'Control interrupted' + # get input + if isinstance(input_image, str): + try: + input_image = Image.open(inputs[i]) + except Exception as e: + shared.log.error(f'Control: image open failed: path={inputs[i]} type=control error={e}') + continue + # match init input + if input_type == 1: + debug('Control Init image: same as control') + init_image = input_image + elif inits is None: + debug('Control Init image: none') + init_image = None + elif isinstance(inits[i], str): + debug(f'Control: init image: {inits[i]}') + try: + init_image = Image.open(inits[i]) + except Exception as e: + shared.log.error(f'Control: image open failed: path={inits[i]} type=init error={e}') + continue + else: + debug(f'Control Init image: {i % 
len(inits) + 1} of {len(inits)}') + init_image = inits[i % len(inits)] + index += 1 + if video is not None and index % (video_skip_frames + 1) != 0: + continue + + # resize + if p.resize_mode != 0 and input_image is not None: + p.extra_generation_params["Control resize"] = f'{resize_time}: {resize_name}' + if p.resize_mode != 0 and input_image is not None and resize_time == 'Before': + debug(f'Control resize: image={input_image} width={width} height={height} mode={p.resize_mode} name={resize_name} sequence={resize_time}') + input_image = images.resize_image(p.resize_mode, input_image, width, height, resize_name) + + # process + if input_image is None: + p.image = None + processed_image = None + debug('Control: process=None image=None') + elif len(active_process) == 0 and unit_type == 'reference': + p.ref_image = p.override or input_image + p.task_args['ref_image'] = p.ref_image + debug(f'Control: process=None image={p.ref_image}') + if p.ref_image is None: + msg = 'Control: attempting reference mode but image is none' + shared.log.error(msg) + restore_pipeline() + return msg + processed_image = p.ref_image + elif len(active_process) == 1: + p.image = active_process[0](input_image) + p.task_args['image'] = p.image + p.extra_generation_params["Control process"] = active_process[0].processor_id + debug(f'Control: process={active_process[0].processor_id} image={p.image}') + if p.image is None: + msg = 'Control: attempting process but output is none' + shared.log.error(msg) + restore_pipeline() + return msg + processed_image = p.image + else: + if len(active_process) > 0: + p.image = [p(input_image) for p in active_process] # list[image] + else: + p.image = [input_image] + p.task_args['image'] = p.image + p.extra_generation_params["Control process"] = [p.processor_id for p in active_process] + debug(f'Control: process={[p.processor_id for p in active_process]} image={p.image}') + if any(img is None for img in p.image): + msg = 'Control: attempting process but output is none' + shared.log.error(msg) + restore_pipeline() + return msg + processed_image = [np.array(i) for i in p.image] + processed_image = util.blend(processed_image) # blend all processed images into one + processed_image = Image.fromarray(processed_image) + + if unit_type == 'controlnet' and input_type == 1: # Init image same as control + p.task_args['image'] = input_image + p.task_args['control_image'] = p.image + p.task_args['strength'] = p.denoising_strength + elif unit_type == 'controlnet' and input_type == 2: # Separate init image + p.task_args['control_image'] = p.image + p.task_args['strength'] = p.denoising_strength + if init_image is None: + shared.log.warning('Control: separate init image not provided') + p.task_args['image'] = input_image + else: + p.task_args['image'] = init_image + + if ip_type == 1 and ip_adapter != 'none': + p.task_args['ip_adapter_image'] = input_image + + if is_generator: + image_txt = f'{processed_image.width}x{processed_image.height}' if processed_image is not None else 'None' + msg = f'process | {index} of {frames if video is not None else len(inputs)} | {"Image" if video is None else "Frame"} {image_txt}' + debug(f'Control yield: {msg}') + yield (None, processed_image, f'Control {msg}') + t2 += time.time() - t2 + + # prepare pipeline + if hasattr(p, 'init_images'): + del p.init_images # control never uses init_image as-is + if pipe is not None: + if not has_models and (unit_type == 'controlnet' or unit_type == 'adapter' or unit_type == 'xs' or unit_type == 'lite'): # run in txt2img or img2img 
mode + if processed_image is not None: + p.init_images = [processed_image] + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) + else: + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) + elif unit_type == 'reference': + p.is_control = True + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) + else: # actual control + p.is_control = True + if 'control_image' in p.task_args: + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) # only controlnet supports img2img + else: + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) + if unit_type == 'lite': + instance.apply(selected_models, p.image, use_conditioning) + + # pipeline + output = None + if pipe is not None: # run new pipeline + debug(f'Control exec pipeline: class={pipe.__class__}') + debug(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)}') + debug(f'Control exec pipeline: p={vars(p)}') + debug(f'Control exec pipeline: args={p.task_args}') + debug(f'Control exec pipeline: image={p.task_args.get("image", None)}') + processed: processing.Processed = processing.process_images(p) # run actual pipeline + output = processed.images if processed is not None else None + # output = pipe(**vars(p)).images # alternative direct pipe exec call + else: # blend all processed images and return + output = [processed_image] + t3 += time.time() - t3 + + # outputs + if output is not None and len(output) > 0: + output_image = output[0] + if output_image is not None: + # resize + if p.resize_mode != 0 and resize_time == 'After': + debug(f'Control resize: image={input_image} width={width} height={height} mode={p.resize_mode} name={resize_name} sequence={resize_time}') + output_image = images.resize_image(p.resize_mode, output_image, width, height, resize_name) + elif hasattr(p, 'width') and hasattr(p, 'height'): + output_image = output_image.resize((p.width, p.height), Image.Resampling.LANCZOS) + + output_images.append(output_image) + if is_generator: + image_txt = f'{output_image.width}x{output_image.height}' if output_image is not None else 'None' + if video is not None: + msg = f'Control output | {index} of {frames} skip {video_skip_frames} | Frame {image_txt}' + else: + msg = f'Control output | {index} of {len(inputs)} | Image {image_txt}' + yield (output_image, processed_image, msg) # result is control_output, proces_output + + if video is not None and frame is not None: + status, frame = video.read() + debug(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/frames:.2f}') + else: + status = False + + if video is not None: + video.release() + + shared.log.info(f'Control: pipeline units={len(active_model)} process={len(active_process)} time={t3-t0:.2f} init={t1-t0:.2f} proc={t2-t1:.2f} ctrl={t3-t2:.2f} outputs={len(output_images)}') + except Exception as e: + shared.log.error(f'Control pipeline failed: type={unit_type} units={len(active_model)} error={e}') + errors.display(e, 'Control') + + shared.sd_model = original_pipeline + pipe = None + devices.torch_gc() + + if len(output_images) == 0: + output_images = None + image_txt = 'images=None' + elif len(output_images) == 1: + output_images = output_images[0] + image_txt = f'| Images 1 | Size {output_images.width}x{output_images.height}' if output_image is 
not None else 'None' + else: + image_txt = f'| Images {len(output_images)} | Size {output_images[0].width}x{output_images[0].height}' if output_image is not None else 'None' + + if video_type != 'None' and isinstance(output_images, list): + p.do_not_save_grid = True # pylint: disable=attribute-defined-outside-init + output_filename = images.save_video(p, filename=None, images=output_images, video_type=video_type, duration=video_duration, loop=video_loop, pad=video_pad, interpolate=video_interpolate, sync=True) + image_txt = f'| Frames {len(output_images)} | Size {output_images[0].width}x{output_images[0].height}' + + image_txt += f' | {util.dict2str(p.extra_generation_params)}' + if hasattr(instance, 'restore'): + instance.restore() + restore_pipeline() + debug(f'Control ready: {image_txt}') + if is_generator: + yield (output_images, processed_image, f'Control ready {image_txt}', output_filename) + else: + return (output_images, processed_image, f'Control ready {image_txt}', output_filename) diff --git a/modules/control/test.py b/modules/control/test.py new file mode 100644 index 000000000..8345997a0 --- /dev/null +++ b/modules/control/test.py @@ -0,0 +1,262 @@ +import math +from PIL import Image, ImageChops +from modules import shared, errors + + +def test_processors(image): + from modules.control import processors + if image is None: + shared.log.error('Image not loaded') + return None, None, None + from PIL import ImageDraw, ImageFont + images = [] + for processor_id in processors.list_models(): + if shared.state.interrupted: + continue + shared.log.info(f'Testing processor: {processor_id}') + processor = processors.Processor(processor_id) + output = image + if processor is None: + shared.log.error(f'Processor load failed: id="{processor_id}"') + processor_id = f'{processor_id} error' + else: + output = processor(image) + processor.reset() + if output.size != image.size: + output = output.resize(image.size, Image.Resampling.LANCZOS) + if output.mode != image.mode: + output = output.convert(image.mode) + shared.log.debug(f'Testing processor: input={image} mode={image.mode} output={output} mode={output.mode}') + diff = ImageChops.difference(image, output) + if not diff.getbbox(): + processor_id = f'{processor_id} null' + draw = ImageDraw.Draw(output) + font = ImageFont.truetype('DejaVuSansMono', 48) + draw.text((10, 10), processor_id, (0,0,0), font=font) + draw.text((8, 8), processor_id, (255,255,255), font=font) + images.append(output) + yield output, None, None, images + rows = round(math.sqrt(len(images))) + cols = math.ceil(len(images) / rows) + w, h = 256, 256 + size = (cols * w + cols, rows * h + rows) + grid = Image.new('RGB', size=size, color='black') + shared.log.info(f'Test processors: images={len(images)} grid={grid}') + for i, image in enumerate(images): + x = (i % cols * w) + (i % cols) + y = (i // cols * h) + (i // cols) + thumb = image.copy().convert('RGB') + thumb.thumbnail((w, h), Image.Resampling.HAMMING) + grid.paste(thumb, box=(x, y)) + yield None, grid, None, images + return None, grid, None, images # preview_process, output_image, output_video, output_gallery + + +def test_controlnets(prompt, negative, image): + from modules import devices, sd_models + from modules.control.units import controlnet + if image is None: + shared.log.error('Image not loaded') + return None, None, None + from PIL import ImageDraw, ImageFont + images = [] + for model_id in controlnet.list_models(): + if model_id is None: + model_id = 'None' + if shared.state.interrupted: + continue + output 
= image + if model_id != 'None': + controlnet = controlnet.ControlNet(model_id=model_id, device=devices.device, dtype=devices.dtype) + if controlnet is None: + shared.log.error(f'ControlNet load failed: id="{model_id}"') + continue + shared.log.info(f'Testing ControlNet: {model_id}') + pipe = controlnet.ControlNetPipeline(controlnet=controlnet.model, pipeline=shared.sd_model) + pipe.pipeline.to(device=devices.device, dtype=devices.dtype) + sd_models.set_diffuser_options(pipe) + try: + res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil') + output = res.images[0] + except Exception as e: + errors.display(e, f'ControlNet {model_id} inference') + model_id = f'{model_id} error' + pipe.restore() + draw = ImageDraw.Draw(output) + font = ImageFont.truetype('DejaVuSansMono', 48) + draw.text((10, 10), model_id, (0,0,0), font=font) + draw.text((8, 8), model_id, (255,255,255), font=font) + images.append(output) + yield output, None, None, images + rows = round(math.sqrt(len(images))) + cols = math.ceil(len(images) / rows) + w, h = 256, 256 + size = (cols * w + cols, rows * h + rows) + grid = Image.new('RGB', size=size, color='black') + shared.log.info(f'Test ControlNets: images={len(images)} grid={grid}') + for i, image in enumerate(images): + x = (i % cols * w) + (i % cols) + y = (i // cols * h) + (i // cols) + thumb = image.copy().convert('RGB') + thumb.thumbnail((w, h), Image.Resampling.HAMMING) + grid.paste(thumb, box=(x, y)) + yield None, grid, None, images + return None, grid, None, images # preview_process, output_image, output_video, output_gallery + + +def test_adapters(prompt, negative, image): + from modules import devices, sd_models + from modules.control.units import t2iadapter + if image is None: + shared.log.error('Image not loaded') + return None, None, None + from PIL import ImageDraw, ImageFont + images = [] + for model_id in t2iadapter.list_models(): + if model_id is None: + model_id = 'None' + if shared.state.interrupted: + continue + output = image.copy() + if model_id != 'None': + adapter = t2iadapter.Adapter(model_id=model_id, device=devices.device, dtype=devices.dtype) + if adapter is None: + shared.log.error(f'Adapter load failed: id="{model_id}"') + continue + shared.log.info(f'Testing Adapter: {model_id}') + pipe = t2iadapter.AdapterPipeline(adapter=adapter.model, pipeline=shared.sd_model) + pipe.pipeline.to(device=devices.device, dtype=devices.dtype) + sd_models.set_diffuser_options(pipe) + image = image.convert('L') if 'Canny' in model_id or 'Sketch' in model_id else image.convert('RGB') + try: + res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil') + output = res.images[0] + except Exception as e: + errors.display(e, f'Adapter {model_id} inference') + model_id = f'{model_id} error' + pipe.restore() + draw = ImageDraw.Draw(output) + font = ImageFont.truetype('DejaVuSansMono', 48) + draw.text((10, 10), model_id, (0,0,0), font=font) + draw.text((8, 8), model_id, (255,255,255), font=font) + images.append(output) + yield output, None, None, images + rows = round(math.sqrt(len(images))) + cols = math.ceil(len(images) / rows) + w, h = 256, 256 + size = (cols * w + cols, rows * h + rows) + grid = Image.new('RGB', size=size, color='black') + shared.log.info(f'Test Adapters: images={len(images)} grid={grid}') + for i, image in enumerate(images): + x = (i % cols * w) + (i % cols) + y = (i // cols * h) + (i // cols) + thumb = image.copy().convert('RGB') + 
thumb.thumbnail((w, h), Image.Resampling.HAMMING) + grid.paste(thumb, box=(x, y)) + yield None, grid, None, images + return None, grid, None, images # preview_process, output_image, output_video, output_gallery + + +def test_xs(prompt, negative, image): + from modules import devices, sd_models + from modules.control.units import xs + if image is None: + shared.log.error('Image not loaded') + return None, None, None + from PIL import ImageDraw, ImageFont + images = [] + for model_id in xs.list_models(): + if model_id is None: + model_id = 'None' + if shared.state.interrupted: + continue + output = image + if model_id != 'None': + cn = xs.ControlNetXS(model_id=model_id, device=devices.device, dtype=devices.dtype) + if cn is None: + shared.log.error(f'ControlNet-XS load failed: id="{model_id}"') + continue + shared.log.info(f'Testing ControlNet-XS: {model_id}') + pipe = xs.ControlNetXSPipeline(controlnet=cn.model, pipeline=shared.sd_model) + pipe.pipeline.to(device=devices.device, dtype=devices.dtype) + sd_models.set_diffuser_options(pipe) + try: + res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil') + output = res.images[0] + except Exception as e: + errors.display(e, f'ControlNet-XS {model_id} inference') + model_id = f'{model_id} error' + pipe.restore() + draw = ImageDraw.Draw(output) + font = ImageFont.truetype('DejaVuSansMono', 48) + draw.text((10, 10), model_id, (0,0,0), font=font) + draw.text((8, 8), model_id, (255,255,255), font=font) + images.append(output) + yield output, None, None, images + rows = round(math.sqrt(len(images))) + cols = math.ceil(len(images) / rows) + w, h = 256, 256 + size = (cols * w + cols, rows * h + rows) + grid = Image.new('RGB', size=size, color='black') + shared.log.info(f'Test ControlNet-XS: images={len(images)} grid={grid}') + for i, image in enumerate(images): + x = (i % cols * w) + (i % cols) + y = (i // cols * h) + (i // cols) + thumb = image.copy().convert('RGB') + thumb.thumbnail((w, h), Image.Resampling.HAMMING) + grid.paste(thumb, box=(x, y)) + yield None, grid, None, images + return None, grid, None, images # preview_process, output_image, output_video, output_gallery + + +def test_lite(prompt, negative, image): + from modules import devices, sd_models + from modules.control.units import lite + if image is None: + shared.log.error('Image not loaded') + return None, None, None + from PIL import ImageDraw, ImageFont + images = [] + for model_id in lite.list_models(): + if model_id is None: + model_id = 'None' + if shared.state.interrupted: + continue + output = image + if model_id != 'None': + cn = lite.ControlLLLite(model_id=model_id, device=devices.device, dtype=devices.dtype) + if cn is None: + shared.log.error(f'Control-LLLite load failed: id="{model_id}"') + continue + shared.log.info(f'Testing Control-LLLite: {model_id}') + pipe = lite.ControlLLitePipeline(pipeline=shared.sd_model) + pipe.apply(controlnet=cn.model, image=image, conditioning=1.0) + pipe.pipeline.to(device=devices.device, dtype=devices.dtype) + sd_models.set_diffuser_options(pipe) + try: + res = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil') + output = res.images[0] + except Exception as e: + errors.display(e, f'Control-LLLite {model_id} inference') + model_id = f'{model_id} error' + pipe.restore() + draw = ImageDraw.Draw(output) + font = ImageFont.truetype('DejaVuSansMono', 48) + draw.text((10, 10), model_id, (0,0,0), font=font) + draw.text((8, 8),
model_id, (255,255,255), font=font) + images.append(output) + yield output, None, None, images + rows = round(math.sqrt(len(images))) + cols = math.ceil(len(images) / rows) + w, h = 256, 256 + size = (cols * w + cols, rows * h + rows) + grid = Image.new('RGB', size=size, color='black') + shared.log.info(f'Test Control-LLLite: images={len(images)} grid={grid}') + for i, image in enumerate(images): + x = (i % cols * w) + (i % cols) + y = (i // cols * h) + (i // cols) + thumb = image.copy().convert('RGB') + thumb.thumbnail((w, h), Image.Resampling.HAMMING) + grid.paste(thumb, box=(x, y)) + yield None, grid, None, images + return None, grid, None, images # preview_process, output_image, output_video, output_gallery diff --git a/modules/control/unit.py b/modules/control/unit.py new file mode 100644 index 000000000..1df54b81b --- /dev/null +++ b/modules/control/unit.py @@ -0,0 +1,163 @@ +from typing import Union +from PIL import Image +from modules.shared import log +from modules.control import processors +from modules.control.units import controlnet +from modules.control.units import xs +from modules.control.units import lite +from modules.control.units import t2iadapter +from modules.control.units import reference # pylint: disable=unused-import + + +default_device = None +default_dtype = None + + +class Unit(): # mashup of gradio controls and mapping to actual implementation classes + def __init__(self, + # values + enabled: bool = None, + strength: float = None, + unit_type: str = None, + start: float = 0, + end: float = 1, + # ui bindings + enabled_cb = None, + reset_btn = None, + process_id = None, + preview_btn = None, + model_id = None, + model_strength = None, + image_input = None, + preview_process = None, + image_upload = None, + control_start = None, + control_end = None, + result_txt = None, + extra_controls: list = [], # noqa B006 + ): + self.enabled = enabled or False + self.type = unit_type + self.strength = strength or 1.0 + self.start = start or 0 + self.end = end or 1 + self.start = min(start or 0, end or 1) + self.end = max(start or 0, end or 1) + # processor always exists, adapter and controlnet are optional + self.process: processors.Processor = processors.Processor() + self.adapter: t2iadapter.Adapter = None + self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS] = None + # map to input image + self.input: Image = image_input + self.override: Image = None + # global settings but passed per-unit + self.factor = 1.0 + self.guess = False + self.start = 0 + self.end = 1 + # reference settings + self.attention = 'Attention' + self.fidelity = 0.5 + self.query_weight = 1.0 + self.adain_weight = 1.0 + + def reset(): + if self.process is not None: + self.process.reset() + if self.adapter is not None: + self.adapter.reset() + if self.controlnet is not None: + self.controlnet.reset() + self.override = None + return [True, 'None', 'None', 1.0] # reset ui values + + def enabled_change(val): + self.enabled = val + + def strength_change(val): + self.strength = val + + def control_change(start, end): + self.start = min(start, end) + self.end = max(start, end) + + def adapter_extra(c1): + self.factor = c1 + + def controlnet_extra(c1): + self.guess = c1 + + def controlnetxs_extra(_c1): + pass # gr.component passed directly to load method + + def reference_extra(c1, c2, c3, c4): + self.attention = c1 + self.fidelity = c2 + self.query_weight = c3 + self.adain_weight = c4 + + def upload_image(image_file): + try: + self.process.override = Image.open(image_file.name) + self.override =
self.process.override + log.debug(f'Control process upload image: path="{image_file.name}" image={self.process.override}') + except Exception as e: + log.error(f'Control process upload image failed: path="{image_file.name}" error={e}') + + # actual init + if self.type == 'adapter': + self.adapter = t2iadapter.Adapter(device=default_device, dtype=default_dtype) + elif self.type == 'controlnet': + self.controlnet = controlnet.ControlNet(device=default_device, dtype=default_dtype) + elif self.type == 'xs': + self.controlnet = xs.ControlNetXS(device=default_device, dtype=default_dtype) + elif self.type == 'lite': + self.controlnet = lite.ControlLLLite(device=default_device, dtype=default_dtype) + elif self.type == 'reference': + pass + else: + log.error(f'Control unknown type: unit={unit_type}') + return + + # bind ui controls to properties if present + if self.type == 'adapter': + if model_id is not None: + model_id.change(fn=self.adapter.load, inputs=[model_id], outputs=[result_txt], show_progress=True) + if extra_controls is not None and len(extra_controls) > 0: + extra_controls[0].change(fn=adapter_extra, inputs=extra_controls) + elif self.type == 'controlnet': + if model_id is not None: + model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True) + if extra_controls is not None and len(extra_controls) > 0: + extra_controls[0].change(fn=controlnet_extra, inputs=extra_controls) + elif self.type == 'xs': + if model_id is not None: + model_id.change(fn=self.controlnet.load, inputs=[model_id, extra_controls[0]], outputs=[result_txt], show_progress=True) + if extra_controls is not None and len(extra_controls) > 0: + extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls) + elif self.type == 'lite': + if model_id is not None: + model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True) + if extra_controls is not None and len(extra_controls) > 0: + extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls) + elif self.type == 'reference': + if extra_controls is not None and len(extra_controls) > 0: + extra_controls[0].change(fn=reference_extra, inputs=extra_controls) + extra_controls[1].change(fn=reference_extra, inputs=extra_controls) + extra_controls[2].change(fn=reference_extra, inputs=extra_controls) + extra_controls[3].change(fn=reference_extra, inputs=extra_controls) + if enabled_cb is not None: + enabled_cb.change(fn=enabled_change, inputs=[enabled_cb]) + if model_strength is not None: + model_strength.change(fn=strength_change, inputs=[model_strength]) + if process_id is not None: + process_id.change(fn=self.process.load, inputs=[process_id], outputs=[result_txt], show_progress=True) + if reset_btn is not None: + reset_btn.click(fn=reset, inputs=[], outputs=[enabled_cb, model_id, process_id, model_strength]) + if preview_btn is not None: + preview_btn.click(fn=self.process.preview, inputs=[self.input], outputs=[preview_process]) # return list of images for gallery + if image_upload is not None: + image_upload.upload(fn=upload_image, inputs=[image_upload], outputs=[]) # return list of images for gallery + if control_start is not None and control_end is not None: + control_start.change(fn=control_change, inputs=[control_start, control_end]) + control_end.change(fn=control_change, inputs=[control_start, control_end]) diff --git a/modules/control/units/controlnet.py b/modules/control/units/controlnet.py new file mode 100644 index 000000000..4b1cf0869 --- /dev/null +++ 
b/modules/control/units/controlnet.py @@ -0,0 +1,169 @@ +import os +import time +from typing import Union +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline +from modules.shared import log, opts +from modules import errors + + +what = 'ControlNet' +debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: CONTROL') +predefined_sd15 = { + 'OpenPose': "lllyasviel/control_v11p_sd15_openpose", + 'Canny': "lllyasviel/control_v11p_sd15_canny", + 'MLDS': "lllyasviel/control_v11p_sd15_mlsd", + 'Scribble': "lllyasviel/control_v11p_sd15_scribble", + 'SoftEdge': "lllyasviel/control_v11p_sd15_softedge", + 'Segment': "lllyasviel/control_v11p_sd15_seg", + 'Depth': "lllyasviel/control_v11f1p_sd15_depth", + 'NormalBae': "lllyasviel/control_v11p_sd15_normalbae", + 'LineArt': "lllyasviel/control_v11p_sd15_lineart", + 'LineArt Anime': "lllyasviel/control_v11p_sd15s2_lineart_anime", + 'Shuffle': "lllyasviel/control_v11e_sd15_shuffle", + 'IP2P': "lllyasviel/control_v11e_sd15_ip2p", + 'HED': "lllyasviel/sd-controlnet-hed", + 'Tile': "lllyasviel/control_v11f1e_sd15_tile", + 'TemporalNet': "CiaraRowles/TemporalNet", +} +predefined_sdxl = { + 'Canny Small XL': 'diffusers/controlnet-canny-sdxl-1.0-small', + 'Canny Mid XL': 'diffusers/controlnet-canny-sdxl-1.0-mid', + 'Canny XL': 'diffusers/controlnet-canny-sdxl-1.0', + 'Depth Zoe XL': 'diffusers/controlnet-zoe-depth-sdxl-1.0', + 'Depth Mid XL': 'diffusers/controlnet-depth-sdxl-1.0-mid', +} +models = {} +all_models = {} +all_models.update(predefined_sd15) +all_models.update(predefined_sdxl) +cache_dir = 'models/control/controlnet' + + +def find_models(): + path = os.path.join(opts.control_dir, 'controlnet') + files = os.listdir(path) + files = [f for f in files if f.endswith('.safetensors')] + downloaded_models = {} + for f in files: + basename = os.path.splitext(f)[0] + downloaded_models[basename] = os.path.join(path, f) + all_models.update(downloaded_models) + return downloaded_models + + +def list_models(refresh=False): + import modules.shared + global models # pylint: disable=global-statement + if not refresh and len(models) > 0: + return models + models = {} + if modules.shared.sd_model_type == 'none': + models = ['None'] + elif modules.shared.sd_model_type == 'sdxl': + models = ['None'] + sorted(predefined_sdxl) + sorted(find_models()) + elif modules.shared.sd_model_type == 'sd': + models = ['None'] + sorted(predefined_sd15) + sorted(find_models()) + else: + log.warning(f'Control {what} model list failed: unknown model type') + models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models()) + debug(f'Control list {what}: path={cache_dir} models={models}') + return models + + +class ControlNet(): + def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None): + self.model: ControlNetModel = None + self.model_id: str = model_id + self.device = device + self.dtype = dtype + self.load_config = { 'cache_dir': cache_dir } + if load_config is not None: + self.load_config.update(load_config) + if model_id is not None: + self.load() + + def reset(self): + if self.model is not None: + log.debug(f'Control {what} model unloaded') + self.model = None + self.model_id = None + + def load(self, model_id: str = None) -> str: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + 
self.reset() + return + model_path = all_models[model_id] + if model_path == '': + return + if model_path is None: + log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + return + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') + if model_path.endswith('.safetensors'): + self.model = ControlNetModel.from_single_file(model_path, **self.load_config) + else: + self.model = ControlNetModel.from_pretrained(model_path, **self.load_config) + if self.device is not None: + self.model.to(self.device) + if self.dtype is not None: + self.model.to(self.dtype) + t1 = time.time() + self.model_id = model_id + log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') + return f'{what} loaded model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + errors.display(e, f'Control {what} load') + return f'{what} failed to load model: {model_id}' + + +class ControlNetPipeline(): + def __init__(self, controlnet: Union[ControlNetModel, list[ControlNetModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None): + t0 = time.time() + self.orig_pipeline = pipeline + self.pipeline = None + if pipeline is None: + log.error('Control model pipeline: model not loaded') + return + elif isinstance(pipeline, StableDiffusionXLPipeline): + self.pipeline = StableDiffusionXLControlNetPipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + text_encoder_2=pipeline.text_encoder_2, + tokenizer=pipeline.tokenizer, + tokenizer_2=pipeline.tokenizer_2, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + feature_extractor=getattr(pipeline, 'feature_extractor', None), + controlnet=controlnet, # can be a list + ).to(pipeline.device) + elif isinstance(pipeline, StableDiffusionPipeline): + self.pipeline = StableDiffusionControlNetPipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + feature_extractor=getattr(pipeline, 'feature_extractor', None), + requires_safety_checker=False, + safety_checker=None, + controlnet=controlnet, # can be a list + ).to(pipeline.device) + else: + log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') + return + if dtype is not None and self.pipeline is not None: + self.pipeline = self.pipeline.to(dtype) + t1 = time.time() + if self.pipeline is not None: + log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') + else: + log.error(f'Control {what} pipeline: not initialized') + + def restore(self): + self.pipeline = None + return self.orig_pipeline diff --git a/modules/control/units/ipadapter.py b/modules/control/units/ipadapter.py new file mode 100644 index 000000000..ed12ca63c --- /dev/null +++ b/modules/control/units/ipadapter.py @@ -0,0 +1,90 @@ +import time +from PIL import Image +from modules import shared, processing, devices + + +image_encoder = None +image_encoder_type = None +loaded = None +ADAPTERS = [ + 'none', + 'ip-adapter_sd15', + 'ip-adapter_sd15_light', + 'ip-adapter-plus_sd15', + 'ip-adapter-plus-face_sd15', + 'ip-adapter-full-face_sd15', + # 'models/ip-adapter_sd15_vit-G', # RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x1024 and 1280x3072) + 'ip-adapter_sdxl', + # 'sdxl_models/ip-adapter_sdxl_vit-h', + # 'sdxl_models/ip-adapter-plus_sdxl_vit-h', + # 
'sdxl_models/ip-adapter-plus-face_sdxl_vit-h', +] + + +def apply_ip_adapter(pipe, p: processing.StableDiffusionProcessing, adapter, scale, image, reset=False): # pylint: disable=arguments-differ + from transformers import CLIPVisionModelWithProjection + # overrides + if hasattr(p, 'ip_adapter_name'): + adapter = p.ip_adapter_name + if hasattr(p, 'ip_adapter_scale'): + scale = p.ip_adapter_scale + if hasattr(p, 'ip_adapter_image'): + image = p.ip_adapter_image + # init code + global loaded, image_encoder, image_encoder_type # pylint: disable=global-statement + if pipe is None: + return + if shared.backend != shared.Backend.DIFFUSERS: + shared.log.warning('IP adapter: not in diffusers mode') + return False + if adapter == 'none': + if hasattr(pipe, 'set_ip_adapter_scale'): + pipe.set_ip_adapter_scale(0) + if loaded is not None: + shared.log.debug('IP adapter: unload attention processor') + pipe.unet.set_default_attn_processor() + pipe.unet.config.encoder_hid_dim_type = None + loaded = None + return False + if image is None: + image = Image.new('RGB', (512, 512), (0, 0, 0)) + if not hasattr(pipe, 'load_ip_adapter'): + shared.log.error(f'IP adapter: pipeline not supported: {pipe.__class__.__name__}') + return False + if getattr(pipe, 'image_encoder', None) is None or getattr(pipe, 'image_encoder', None) == (None, None): + if shared.sd_model_type == 'sd': + subfolder = 'models/image_encoder' + elif shared.sd_model_type == 'sdxl': + subfolder = 'sdxl_models/image_encoder' + else: + shared.log.error(f'IP adapter: unsupported model type: {shared.sd_model_type}') + return False + if image_encoder is None or image_encoder_type != shared.sd_model_type: + try: + image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder=subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.diffusers_dir, use_safetensors=True).to(devices.device) + image_encoder_type = shared.sd_model_type + except Exception as e: + shared.log.error(f'IP adapter: failed to load image encoder: {e}') + return False + pipe.image_encoder = image_encoder + + # main code + subfolder = 'models' if 'sd15' in adapter else 'sdxl_models' + if adapter != loaded or getattr(pipe.unet.config, 'encoder_hid_dim_type', None) is None or reset: + t0 = time.time() + if loaded is not None: + # shared.log.debug('IP adapter: reset attention processor') + pipe.unet.set_default_attn_processor() + loaded = None + else: + shared.log.debug('IP adapter: load attention processor') + pipe.load_ip_adapter("h94/IP-Adapter", subfolder=subfolder, weight_name=f'{adapter}.safetensors') + t1 = time.time() + shared.log.info(f'IP adapter load: adapter="{adapter}" scale={scale} image={image} time={t1-t0:.2f}') + loaded = adapter + else: + shared.log.debug(f'IP adapter cache: adapter="{adapter}" scale={scale} image={image}') + pipe.set_ip_adapter_scale(scale) + p.task_args['ip_adapter_image'] = p.batch_size * [image] + p.extra_generation_params["IP Adapter"] = f'{adapter}:{scale}' + return True diff --git a/modules/control/units/lite.py b/modules/control/units/lite.py new file mode 100644 index 000000000..9796f77f1 --- /dev/null +++ b/modules/control/units/lite.py @@ -0,0 +1,135 @@ +import os +import time +from typing import Union +import numpy as np +from PIL import Image +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from modules.shared import log, opts +from modules import errors +from modules.control.units.lite_model import ControlNetLLLite + + +what = 'ControlLLLite' +debug = log.trace if 
os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: CONTROL') +predefined_sd15 = { +} +predefined_sdxl = { + 'Canny XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny', + 'Canny anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny_anime', + 'Depth anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01008016e_sdxl_depth_anime', + 'Blur anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01016032e_sdxl_blur_anime_beta', + 'Pose anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_pose_anime', + 'Replicate anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_replicate_anime_v2', +} +models = {} +all_models = {} +all_models.update(predefined_sd15) +all_models.update(predefined_sdxl) +cache_dir = 'models/control/lite' + + +def find_models(): + path = os.path.join(opts.control_dir, 'lite') + files = os.listdir(path) + files = [f for f in files if f.endswith('.safetensors')] + downloaded_models = {} + for f in files: + basename = os.path.splitext(f)[0] + downloaded_models[basename] = os.path.join(path, f) + all_models.update(downloaded_models) + return downloaded_models + + +def list_models(refresh=False): + import modules.shared + global models # pylint: disable=global-statement + if not refresh and len(models) > 0: + return models + models = {} + if modules.shared.sd_model_type == 'none': + models = ['None'] + elif modules.shared.sd_model_type == 'sdxl': + models = ['None'] + sorted(predefined_sdxl) + sorted(find_models()) + elif modules.shared.sd_model_type == 'sd': + models = ['None'] + sorted(predefined_sd15) + sorted(find_models()) + else: + log.warning(f'Control {what} model list failed: unknown model type') + models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models()) + debug(f'Control list {what}: path={cache_dir} models={models}') + return models + + +class ControlLLLite(): + def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None): + self.model: ControlNetLLLite = None + self.model_id: str = model_id + self.device = device + self.dtype = dtype + self.load_config = { 'cache_dir': cache_dir } + if load_config is not None: + self.load_config.update(load_config) + if model_id is not None: + self.load() + + def reset(self): + if self.model is not None: + log.debug(f'Control {what} model unloaded') + self.model = None + self.model_id = None + + def load(self, model_id: str = None) -> str: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + self.reset() + return + model_path = all_models[model_id] + if model_path == '': + return + if model_path is None: + log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + return + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}') + if model_path.endswith('.safetensors'): + self.model = ControlNetLLLite(model_path) + else: + import huggingface_hub as hf + folder, filename = os.path.split(model_path) + model_path = hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=cache_dir) + self.model = ControlNetLLLite(model_path) + if self.device is not None: + self.model.to(self.device) + if self.dtype is not None: + self.model.to(self.dtype) + t1 = time.time() + self.model_id = model_id + log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') + return f'{what} loaded 
model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + errors.display(e, f'Control {what} load') + return f'{what} failed to load model: {model_id}' + + +class ControlLLitePipeline(): + def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline]): + self.pipeline = pipeline + self.nets = [] + + def apply(self, controlnet: Union[ControlNetLLLite, list[ControlNetLLLite]], image, conditioning): + if image is None: + return + self.nets = [controlnet] if isinstance(controlnet, ControlNetLLLite) else controlnet + debug(f'Control {what} apply: models={len(self.nets)} image={image} conditioning={conditioning}') + weight = [conditioning] if isinstance(conditioning, float) else conditioning + images = [image] if isinstance(image, Image.Image) else image + images = [i.convert('RGB') for i in images] + for i, cn in enumerate(self.nets): + cn.apply(pipe=self.pipeline, cond=np.asarray(images[i % len(images)]), weight=weight[i % len(weight)]) + + def restore(self): + from modules.control.units.lite_model import clear_all_lllite + clear_all_lllite() + self.nets = [] diff --git a/modules/control/units/lite_model.py b/modules/control/units/lite_model.py new file mode 100644 index 000000000..991f3c31a --- /dev/null +++ b/modules/control/units/lite_model.py @@ -0,0 +1,202 @@ +# Credits: +# + +import re +import torch +from safetensors.torch import load_file + + +all_hack = {} + + +class LLLiteModule(torch.nn.Module): + def __init__( + self, + name: str, + is_conv2d: bool, + in_dim: int, + depth: int, + cond_emb_dim: int, + mlp_dim: int, + ): + super().__init__() + self.name = name + self.is_conv2d = is_conv2d + self.is_first = False + modules = [] + modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2 + if depth == 1: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + elif depth == 2: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) + elif depth == 3: + # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + self.conditioning1 = torch.nn.Sequential(*modules) + if self.is_conv2d: + self.down = torch.nn.Sequential( + torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), + ) + else: + self.down = torch.nn.Sequential( + torch.nn.Linear(in_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Linear(mlp_dim, in_dim), + ) + self.depth = depth + self.cond_image = None + self.cond_emb = None + + def set_cond_image(self, cond_image): + self.cond_image = cond_image + self.cond_emb = 
None + + def forward(self, x): + if self.cond_emb is None: + # print(f"cond_emb is None, {self.name}") + cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype)) + # if blk_shape is not None: + # b, c, h, w = blk_shape + # cx = torch.nn.functional.interpolate(cx, (h, w), mode="nearest-exact") + if not self.is_conv2d: + # reshape / b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + self.cond_emb = cx + cx = self.cond_emb + + # uncond/condでxはバッチサイズが2倍 + if x.shape[0] != cx.shape[0]: + if self.is_conv2d: + cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1) + else: + # print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0]) + cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1) + + cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2) + cx = self.mid(cx) + cx = self.up(cx) + return cx + + +def clear_all_lllite(): + global all_hack # pylint: disable=global-statement + for k, v in all_hack.items(): + k.forward = v + k.lllite_list = [] + all_hack = {} + return + + +class ControlNetLLLite(torch.nn.Module): # pylint: disable=abstract-method + def __init__(self, path: str): + super().__init__() + module_weights = {} + try: + state_dict = load_file(path) + except Exception as e: + raise RuntimeError(f"Failed to load {path}") from e + for key, value in state_dict.items(): + fragments = key.split(".") + module_name = fragments[0] + weight_name = ".".join(fragments[1:]) + if module_name not in module_weights: + module_weights[module_name] = {} + module_weights[module_name][weight_name] = value + modules = {} + for module_name, weights in module_weights.items(): + if "conditioning1.4.weight" in weights: + depth = 3 + elif weights["conditioning1.2.weight"].shape[-1] == 4: + depth = 2 + else: + depth = 1 + + module = LLLiteModule( + name=module_name, + is_conv2d=weights["down.0.weight"].ndim == 4, + in_dim=weights["down.0.weight"].shape[1], + depth=depth, + cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2, + mlp_dim=weights["down.0.weight"].shape[0], + ) + # info = module.load_state_dict(weights) + modules[module_name] = module + setattr(self, module_name, module) + if len(modules) == 1: + module.is_first = True + + self.modules = modules + return + + @torch.no_grad() + def apply(self, pipe, cond, weight): # pylint: disable=arguments-differ + map_down_lllite_to_unet = {4: (1, 0), 5: (1, 1), 7: (2, 0), 8: (2, 1)} + model = pipe.unet + if type(cond) != torch.Tensor: + cond = torch.tensor(cond) + cond = cond/255 # 0-255 -> 0-1 + cond_image = cond.unsqueeze(dim=0).permute(0, 3, 1, 2) # h,w,c -> b,c,h,w + cond_image = cond_image * 2.0 - 1.0 # 0-1 -> -1-1 + + for module in self.modules.values(): + module.set_cond_image(cond_image) + for k, v in self.modules.items(): + k = k.replace('middle_block', 'middle_blocks_0') + match = re.match("lllite_unet_(.*)_blocks_(.*)_1_transformer_blocks_(.*)_(.*)_to_(.*)", k, re.M | re.I) + assert match, 'Failed to load ControlLLLite!' 
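+ # key format matched by the regex above: lllite_unet_<input|middle|output>_blocks_<block>_1_transformer_blocks_<n>_<attn>_to_<proj> + # groups: root (input/middle/output), LDM block index, transformer block index, attention name (attn1/attn2), projection suffix (q/k/v/...) + # map_down_lllite_to_unet translates LDM down-block indices 4/5/7/8 into diffusers (down_block index, attention index) pairs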
+ root = match.group(1) + block = match.group(2) + block_number = match.group(3) + attn_name = match.group(4) + proj_name = match.group(5) + if root == 'input': + mapped_block, mapped_number = map_down_lllite_to_unet[int(block)] + b = model.down_blocks[mapped_block].attentions[int(mapped_number)].transformer_blocks[int(block_number)] + elif root == 'output': + # TODO: Map up unet blocks to lite blocks + print(f'Not implemented: {root}') + else: + b = model.mid_block.attentions[0].transformer_blocks[int(block_number)] + b = getattr(b, attn_name, None) + assert b is not None, 'Failed to load ControlLLLite!' + b = getattr(b, 'to_' + proj_name, None) + assert b is not None, 'Failed to load ControlLLLite!' + if not hasattr(b, 'lllite_list'): + b.lllite_list = [] + if len(b.lllite_list) == 0: + all_hack[b] = b.forward + b.forward = self.get_hacked_forward(original_forward=b.forward, model=model, blk=b) + b.lllite_list.append((weight, v)) + return + + def get_hacked_forward(self, original_forward, model, blk): + @torch.no_grad() + def forward(x, **kwargs): + hack = 0 + for weight, module in blk.lllite_list: + module.to(x.device) + module.to(x.dtype) + hack = hack + module(x) * weight + x = x + hack + return original_forward(x, **kwargs) + return forward diff --git a/modules/control/units/reference.py b/modules/control/units/reference.py new file mode 100644 index 000000000..ccc3b2277 --- /dev/null +++ b/modules/control/units/reference.py @@ -0,0 +1,58 @@ +import time +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from modules.control.proc.reference_sd15 import StableDiffusionReferencePipeline +from modules.control.proc.reference_sdxl import StableDiffusionXLReferencePipeline +from modules.shared import log + + +what = 'Reference' + + +def list_models(): + return ['Reference'] + + +class ReferencePipeline(): + def __init__(self, pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline, dtype = None): + t0 = time.time() + self.orig_pipeline = pipeline + self.pipeline = None + if pipeline is None: + log.error(f'Control {what} model pipeline: model not loaded') + return + if isinstance(pipeline, StableDiffusionXLPipeline): + self.pipeline = StableDiffusionXLReferencePipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + text_encoder_2=pipeline.text_encoder_2, + tokenizer=pipeline.tokenizer, + tokenizer_2=pipeline.tokenizer_2, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + feature_extractor=getattr(pipeline, 'feature_extractor', None), + ).to(pipeline.device) + elif isinstance(pipeline, StableDiffusionPipeline): + self.pipeline = StableDiffusionReferencePipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + feature_extractor=getattr(pipeline, 'feature_extractor', None), + requires_safety_checker=False, + safety_checker=None, + ).to(pipeline.device) + else: + log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') + return + if dtype is not None and self.pipeline is not None: + self.pipeline = self.pipeline.to(dtype) + t1 = time.time() + if self.pipeline is not None: + log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') + else: + log.error(f'Control {what} pipeline: not initialized') + + def restore(self): + self.pipeline = None + return self.orig_pipeline diff --git a/modules/control/units/t2iadapter.py b/modules/control/units/t2iadapter.py new file mode 100644 
index 000000000..39b05c9dd --- /dev/null +++ b/modules/control/units/t2iadapter.py @@ -0,0 +1,157 @@ +import os +import time +from typing import Union +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import +from modules.shared import log +from modules import errors + + +what = 'T2I-Adapter' +debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: CONTROL') +predefined_sd15 = { + 'Segment': 'TencentARC/t2iadapter_seg_sd14v1', + 'Zoe Depth': 'TencentARC/t2iadapter_zoedepth_sd15v1', + 'OpenPose': 'TencentARC/t2iadapter_openpose_sd14v1', + 'KeyPose': 'TencentARC/t2iadapter_keypose_sd14v1', + 'Color': 'TencentARC/t2iadapter_color_sd14v1', + 'Depth v1': 'TencentARC/t2iadapter_depth_sd14v1', + 'Depth v2': 'TencentARC/t2iadapter_depth_sd15v2', + 'Canny v1': 'TencentARC/t2iadapter_canny_sd14v1', + 'Canny v2': 'TencentARC/t2iadapter_canny_sd15v2', + 'Sketch v1': 'TencentARC/t2iadapter_sketch_sd14v1', + 'Sketch v2': 'TencentARC/t2iadapter_sketch_sd15v2', +} +predefined_sdxl = { + 'Canny XL': 'TencentARC/t2i-adapter-canny-sdxl-1.0', + 'LineArt XL': 'TencentARC/t2i-adapter-lineart-sdxl-1.0', + 'Sketch XL': 'TencentARC/t2i-adapter-sketch-sdxl-1.0', + 'Zoe Depth XL': 'TencentARC/t2i-adapter-depth-zoe-sdxl-1.0', + 'OpenPose XL': 'TencentARC/t2i-adapter-openpose-sdxl-1.0', + 'Midas Depth XL': 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0', +} +models = {} +all_models = {} +all_models.update(predefined_sd15) +all_models.update(predefined_sdxl) +cache_dir = 'models/control/adapter' + + +def list_models(refresh=False): + import modules.shared + global models # pylint: disable=global-statement + if not refresh and len(models) > 0: + return models + models = {} + if modules.shared.sd_model_type == 'none': + models = ['None'] + elif modules.shared.sd_model_type == 'sdxl': + models = ['None'] + sorted(predefined_sdxl) + elif modules.shared.sd_model_type == 'sd': + models = ['None'] + sorted(predefined_sd15) + else: + log.warning(f'Control {what} model list failed: unknown model type') + models = ['None'] + sorted(list(predefined_sd15) + list(predefined_sdxl)) + debug(f'Control list {what}: path={cache_dir} models={models}') + return models + + +class AdapterModel(T2IAdapter): + pass + + +class Adapter(): + def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None): + self.model: AdapterModel = None + self.model_id: str = model_id + self.device = device + self.dtype = dtype + self.load_config = { 'cache_dir': cache_dir } + if load_config is not None: + self.load_config.update(load_config) + if model_id is not None: + self.load() + + def reset(self): + if self.model is not None: + log.debug(f'Control {what} model unloaded') + self.model = None + self.model_id = None + + def load(self, model_id: str = None) -> str: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + self.reset() + return + model_path = all_models[model_id] + if model_path is None: + log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + return + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') + self.model = T2IAdapter.from_pretrained(model_path, **self.load_config) + if self.device is not None: + self.model.to(self.device) + if self.dtype is not None: + self.model.to(self.dtype) + t1 = 
time.time() + self.model_id = model_id + log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') + return f'{what} loaded model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + errors.display(e, f'Control {what} load') + return f'{what} failed to load model: {model_id}' + + +class AdapterPipeline(): + def __init__(self, adapter: Union[T2IAdapter, list[T2IAdapter]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None): + t0 = time.time() + self.orig_pipeline = pipeline + self.pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline] = None + if pipeline is None: + log.error(f'Control {what} pipeline: model not loaded') + return + if isinstance(adapter, list) and len(adapter) > 1: # TODO use MultiAdapter + adapter = MultiAdapter(adapter) + adapter.to(device=pipeline.device, dtype=pipeline.dtype) + if isinstance(pipeline, StableDiffusionXLPipeline): + self.pipeline = StableDiffusionXLAdapterPipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + text_encoder_2=pipeline.text_encoder_2, + tokenizer=pipeline.tokenizer, + tokenizer_2=pipeline.tokenizer_2, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + feature_extractor=getattr(pipeline, 'feature_extractor', None), + adapter=adapter, + ).to(pipeline.device) + elif isinstance(pipeline, StableDiffusionPipeline): + self.pipeline = StableDiffusionAdapterPipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + feature_extractor=getattr(pipeline, 'feature_extractor', None), + requires_safety_checker=False, + safety_checker=None, + adapter=adapter, + ).to(pipeline.device) + else: + log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') + return + if dtype is not None and self.pipeline is not None: + self.pipeline.dtype = dtype + t1 = time.time() + if self.pipeline is not None: + log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') + else: + log.error(f'Control {what} pipeline: not initialized') + + + def restore(self): + self.pipeline = None + return self.orig_pipeline diff --git a/modules/control/units/xs.py b/modules/control/units/xs.py new file mode 100644 index 000000000..b0dc84659 --- /dev/null +++ b/modules/control/units/xs.py @@ -0,0 +1,154 @@ +import os +import time +from typing import Union +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from modules.shared import log, opts +from modules import errors +from modules.control.units.xs_model import ControlNetXSModel +from modules.control.units.xs_pipe import StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline + + +what = 'ControlNet-XS' +debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: CONTROL') +predefined_sd15 = { +} +predefined_sdxl = { + 'Canny': 'UmerHA/ConrolNetXS-SDXL-canny', + 'Depth': 'UmerHA/ConrolNetXS-SDXL-depth', +} +models = {} +all_models = {} +all_models.update(predefined_sd15) +all_models.update(predefined_sdxl) +cache_dir = 'models/control/xs' + + +def find_models(): + path = os.path.join(opts.control_dir, 'xs') + files = os.listdir(path) + files = [f for f in files if f.endswith('.safetensors')] + downloaded_models = {} + for f in files: + basename = os.path.splitext(f)[0] + downloaded_models[basename] = 
os.path.join(path, f) + all_models.update(downloaded_models) + return downloaded_models + + +def list_models(refresh=False): + global models # pylint: disable=global-statement + import modules.shared + if not refresh and len(models) > 0: + return models + models = {} + if modules.shared.sd_model_type == 'none': + models = ['None'] + elif modules.shared.sd_model_type == 'sdxl': + models = ['None'] + sorted(predefined_sdxl) + sorted(find_models()) + elif modules.shared.sd_model_type == 'sd': + models = ['None'] + sorted(predefined_sd15) + sorted(find_models()) + else: + log.error(f'Control {what} model list failed: unknown model type') + models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models()) + debug(f'Control list {what}: path={cache_dir} models={models}') + return models + + +class ControlNetXS(): + def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None): + self.model: ControlNetXSModel = None + self.model_id: str = model_id + self.device = device + self.dtype = dtype + self.load_config = { 'cache_dir': cache_dir, 'learn_embedding': True } + if load_config is not None: + self.load_config.update(load_config) + if model_id is not None: + self.load() + + def reset(self): + if self.model is not None: + log.debug(f'Control {what} model unloaded') + self.model = None + self.model_id = None + + def load(self, model_id: str = None, time_embedding_mix: float = 0.0) -> str: + try: + t0 = time.time() + model_id = model_id or self.model_id + if model_id is None or model_id == 'None': + self.reset() + return + model_path = all_models[model_id] + if model_path == '': + return + if model_path is None: + log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') + return + self.load_config['time_embedding_mix'] = time_embedding_mix + log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}') + if model_path.endswith('.safetensors'): + self.model = ControlNetXSModel.from_single_file(model_path, **self.load_config) + else: + self.model = ControlNetXSModel.from_pretrained(model_path, **self.load_config) + if self.device is not None: + self.model.to(self.device) + if self.dtype is not None: + self.model.to(self.dtype) + t1 = time.time() + self.model_id = model_id + log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') + return f'{what} loaded model: {model_id}' + except Exception as e: + log.error(f'Control {what} model load failed: id="{model_id}" error={e}') + errors.display(e, f'Control {what} load') + return f'{what} failed to load model: {model_id}' + + +class ControlNetXSPipeline(): + def __init__(self, controlnet: Union[ControlNetXSModel, list[ControlNetXSModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None): + t0 = time.time() + self.orig_pipeline = pipeline + self.pipeline = None + if pipeline is None: + log.error(f'Control {what} pipeline: model not loaded') + return + if isinstance(pipeline, StableDiffusionXLPipeline): + self.pipeline = StableDiffusionXLControlNetXSPipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + text_encoder_2=pipeline.text_encoder_2, + tokenizer=pipeline.tokenizer, + tokenizer_2=pipeline.tokenizer_2, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + # feature_extractor=getattr(pipeline, 'feature_extractor', None), + controlnet=controlnet, # can be a list + ).to(pipeline.device) + elif isinstance(pipeline, StableDiffusionPipeline): + 
self.pipeline = StableDiffusionControlNetXSPipeline( + vae=pipeline.vae, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + unet=pipeline.unet, + scheduler=pipeline.scheduler, + feature_extractor=getattr(pipeline, 'feature_extractor', None), + requires_safety_checker=False, + safety_checker=None, + controlnet=controlnet, # can be a list + ).to(pipeline.device) + else: + log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') + return + if dtype is not None and self.pipeline is not None: + self.pipeline = self.pipeline.to(dtype) + t1 = time.time() + if self.pipeline is not None: + log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') + else: + log.error(f'Control {what} pipeline: not initialized') + + def restore(self): + self.pipeline = None + return self.orig_pipeline diff --git a/modules/control/units/xs_model.py b/modules/control/units/xs_model.py new file mode 100644 index 000000000..c6419b44d --- /dev/null +++ b/modules/control/units/xs_model.py @@ -0,0 +1,1016 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.normalization import GroupNorm + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.lora import LoRACompatibleConv +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + Downsample2D, + ResnetBlock2D, + Transformer2DModel, + UpBlock2D, + Upsample2D, +) +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetXSOutput(BaseOutput): + """ + The output of [`ControlNetXSModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model + output, but is already the final output. + """ + + sample: torch.FloatTensor = None + + +# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. 
This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetXSModel(ModelMixin, ConfigMixin): + r""" + A ControlNet-XS model + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation + of [`UNet2DConditionModel`] for them. + + Parameters: + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + time_embedding_input_dim (`int`, defaults to 320): + Dimension of input into time embedding. Needs to be same as in the base model. + time_embedding_dim (`int`, defaults to 1280): + Dimension of output from time embedding. Needs to be same as in the base model. + learn_embedding (`bool`, defaults to `False`): + Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of + the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. + time_embedding_mix (`float`, defaults to 1.0): + Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the + control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. + base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): + Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. + """ + + @classmethod + def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): + """ + Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). 
+ + Parameters: + base_model (`UNet2DConditionModel`): + Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. + is_sdxl (`bool`, defaults to `True`): + Whether passed `base_model` is a StableDiffusion-XL model. + """ + + def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): + """ + Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). + The original ControlNet-XS model, however, define the number of attention heads. + That's why compute the dimensions needed to get the correct number of attention heads. + """ + block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] + dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] + return dim_attn_heads + + if is_sdxl: + return ControlNetXSModel.from_unet( + base_model, + time_embedding_mix=0.95, + learn_embedding=True, + size_ratio=0.1, + conditioning_embedding_out_channels=(16, 32, 96, 256), + num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), + ) + else: + return ControlNetXSModel.from_unet( + base_model, + time_embedding_mix=1.0, + learn_embedding=True, + size_ratio=0.0125, + conditioning_embedding_out_channels=(16, 32, 96, 256), + num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), + ) + + @classmethod + def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): + """To create correctly sized connections between base and control model, we need to know + the input and output channels of each subblock. + + Parameters: + unet (`UNet2DConditionModel`): + Unet of which the subblock channels sizes are to be gathered. + base_or_control (`str`): + Needs to be either "base" or "control". If "base", decoder is also considered. + """ + if base_or_control not in ["base", "control"]: + raise ValueError("`base_or_control` needs to be either `base` or `control`") + + channel_sizes = {"down": [], "mid": [], "up": []} + + # input convolution + channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) + + # encoder blocks + for module in unet.down_blocks: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + for r in module.resnets: + channel_sizes["down"].append((r.in_channels, r.out_channels)) + if module.downsamplers: + channel_sizes["down"].append( + (module.downsamplers[0].channels, module.downsamplers[0].out_channels) + ) + else: + raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") + + # middle block + channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) + + # decoder blocks + if base_or_control == "base": + for module in unet.up_blocks: + if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): + for r in module.resnets: + channel_sizes["up"].append((r.in_channels, r.out_channels)) + else: + raise ValueError( + f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." 
+ ) + + return channel_sizes + + @register_to_config + def __init__( + self, + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + controlnet_conditioning_channel_order: str = "rgb", + time_embedding_input_dim: int = 320, + time_embedding_dim: int = 1280, + time_embedding_mix: float = 1.0, + learn_embedding: bool = False, + base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { + "down": [ + (4, 320), + (320, 320), + (320, 320), + (320, 320), + (320, 640), + (640, 640), + (640, 640), + (640, 1280), + (1280, 1280), + ], + "mid": [(1280, 1280)], + "up": [ + (2560, 1280), + (2560, 1280), + (1920, 1280), + (1920, 640), + (1280, 640), + (960, 640), + (960, 320), + (640, 320), + (640, 320), + ], + }, + sample_size: Optional[int] = None, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + norm_num_groups: Optional[int] = 32, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + upcast_attention: bool = False, + ): + super().__init__() + + # 1 - Create control unet + self.control_model = UNet2DConditionModel( + sample_size=sample_size, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + time_embedding_dim=time_embedding_dim, + ) + + # 2 - Do model surgery on control model + # 2.1 - Allow to use the same time information as the base model + adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim) + + # 2.2 - Allow for information infusion from base model + + # We concat the output of each base encoder subblocks to the input of the next control encoder subblock + # (We ignore the 1st element, as it represents the `conv_in`.) 
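+ # with the default base_model_channel_sizes above this yields extra_input_channels = [320, 320, 320, 320, 640, 640, 640, 1280], + # i.e. every control encoder subblock (and the mid resnet) has its input widened by the width of the base feature map concatenated to it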
+ extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]] + it_extra_input_channels = iter(extra_input_channels) + + for b, block in enumerate(self.control_model.down_blocks): + for r in range(len(block.resnets)): + increase_block_input_in_encoder_resnet( + self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels) + ) + + if block.downsamplers: + increase_block_input_in_encoder_downsampler( + self.control_model, block_no=b, by=next(it_extra_input_channels) + ) + + increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1]) + + # 2.3 - Make group norms work with modified channel sizes + adjust_group_norms(self.control_model) + + # 3 - Gather Channel Sizes + self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") + self.ch_inout_base = base_model_channel_sizes + + # 4 - Build connections between base and control model + self.down_zero_convs_out = nn.ModuleList([]) + self.down_zero_convs_in = nn.ModuleList([]) + self.middle_block_out = nn.ModuleList([]) + self.middle_block_in = nn.ModuleList([]) + self.up_zero_convs_out = nn.ModuleList([]) + self.up_zero_convs_in = nn.ModuleList([]) + + for ch_io_base in self.ch_inout_base["down"]: + self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) + for i in range(len(self.ch_inout_ctrl["down"])): + self.down_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) + ) + + self.middle_block_out = self._make_zero_conv( + self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] + ) + + self.up_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) + ) + for i in range(1, len(self.ch_inout_ctrl["down"])): + self.up_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) + ) + + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # In the mininal implementation setting, we only need the control model up to the mid block + del self.control_model.up_blocks + del self.control_model.conv_norm_out + del self.control_model.conv_out + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + controlnet_conditioning_channel_order: str = "rgb", + learn_embedding: bool = False, + time_embedding_mix: float = 1.0, + block_out_channels: Optional[Tuple[int]] = None, + size_ratio: Optional[float] = None, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + norm_num_groups: Optional[int] = None, + ): + r""" + Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. 
+            controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+                The channel order of the conditioning image. Will convert to `rgb` if it's `bgr`.
+            learn_embedding (`bool`, defaults to `False`):
+                Whether to use the time embedding of the control model. If yes, the time embedding is a linear interpolation
+                of the time embeddings of the control and base model with interpolation parameter
+                `time_embedding_mix**3`.
+            time_embedding_mix (`float`, defaults to 1.0):
+                Linear interpolation parameter used if `learn_embedding` is `True`.
+            block_out_channels (`Tuple[int]`, *optional*):
+                Down blocks output channels in control model. Either this or `size_ratio` must be given.
+            size_ratio (float, *optional*):
+                When given, block_out_channels is set to a relative fraction of the base model's block_out_channels.
+                Either this or `block_out_channels` must be given.
+            num_attention_heads (`Union[int, Tuple[int]]`, *optional*):
+                The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+            norm_num_groups (int, *optional*, defaults to `None`):
+                The number of groups to use for the normalization of the control unet. If `None`, the base model's
+                `norm_num_groups` is reused if it divides all `block_out_channels`; otherwise `min(block_out_channels)` is used.
+        """
+
+        # Check input
+        fixed_size = block_out_channels is not None
+        relative_size = size_ratio is not None
+        if not (fixed_size ^ relative_size):
+            raise ValueError(
+                "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
+            )
+
+        # Create model
+        if block_out_channels is None:
+            block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels]
+
+        # Check that attention heads and group norms match channel sizes
+        # - attention heads
+        def attn_heads_match_channel_sizes(attn_heads, channel_sizes):
+            if isinstance(attn_heads, (tuple, list)):
+                return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes))
+            else:
+                return all(c % attn_heads == 0 for c in channel_sizes)
+
+        num_attention_heads = num_attention_heads or unet.config.attention_head_dim
+        if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels):
+            raise ValueError(
+                f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually."
+            )
+
+        # - group norms
+        def group_norms_match_channel_sizes(num_groups, channel_sizes):
+            return all(c % num_groups == 0 for c in channel_sizes)
+
+        if norm_num_groups is None:
+            if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels):
+                norm_num_groups = unet.config.norm_num_groups
+            else:
+                norm_num_groups = min(block_out_channels)
+
+                if group_norms_match_channel_sizes(norm_num_groups, block_out_channels):
+                    print(
+                        f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all `block_out_channels` ({block_out_channels}). Set it explicitly to remove this information."
+                    )
+                else:
+                    raise ValueError(
+                        f"`block_out_channels` ({block_out_channels}) don't match the base model's `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all `block_out_channels`."
+ ) + + def get_time_emb_input_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_1.in_features + + def get_time_emb_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_2.out_features + + # Clone params from base unet if + # (i) it's required to build SD or SDXL, and + # (ii) it's not used for the time embedding (as time embedding of control model is never used), and + # (iii) it's not set further below anyway + to_keep = [ + "cross_attention_dim", + "down_block_types", + "sample_size", + "transformer_layers_per_block", + "up_block_types", + "upcast_attention", + ] + kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} + kwargs.update(block_out_channels=block_out_channels) + kwargs.update(num_attention_heads=num_attention_heads) + kwargs.update(norm_num_groups=norm_num_groups) + + # Add controlnetxs-specific params + kwargs.update( + conditioning_channels=conditioning_channels, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + time_embedding_input_dim=get_time_emb_input_dim(unet), + time_embedding_dim=get_time_emb_dim(unet), + time_embedding_mix=time_embedding_mix, + learn_embedding=learn_embedding, + base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + return cls(**kwargs) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + return self.control_model.attn_processors + + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + self.control_model.set_attn_processor(processor, _remove_lora) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.control_model.set_default_attn_processor() + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + self.control_model.set_attention_slice(slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (UNet2DConditionModel)): + if value: + module.enable_gradient_checkpointing() + else: + module.disable_gradient_checkpointing() + + def forward( + self, + base_model: UNet2DConditionModel, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + return_dict: bool = True, + ) -> Union[ControlNetXSOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + base_model (`UNet2DConditionModel`): + The base unet model we want to control. + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + How much the control model affects the base model outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... 
+ elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # scale control strength + n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) + scale_list = torch.full((n_connections,), conditioning_scale) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = base_model.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + if self.config.learn_embedding: + ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) + base_temb = base_model.time_embedding(t_emb, timestep_cond) + interpolation_param = self.config.time_embedding_mix**0.3 + + temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) + else: + temb = base_model.time_embedding(t_emb) + + # added time & text embeddings + aug_emb = None + + if base_model.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if base_model.config.class_embed_type == "timestep": + class_labels = base_model.time_proj(class_labels) + + class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) + temb = temb + class_emb + + if base_model.config.addition_embed_type is not None: + if base_model.config.addition_embed_type == "text": + aug_emb = base_model.add_embedding(encoder_hidden_states) + elif base_model.config.addition_embed_type == "text_image": + raise NotImplementedError() + elif base_model.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = base_model.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(temb.dtype) + aug_emb = base_model.add_embedding(add_embeds) + elif 
base_model.config.addition_embed_type == "image": + raise NotImplementedError() + elif base_model.config.addition_embed_type == "image_hint": + raise NotImplementedError() + + temb = temb + aug_emb if aug_emb is not None else temb + + # text embeddings + cemb = encoder_hidden_states + + # Preparation + guided_hint = self.controlnet_cond_embedding(controlnet_cond) + + h_ctrl = h_base = sample + hs_base, hs_ctrl = [], [] + it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map( + iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) + ) + scales = iter(scale_list) + + base_down_subblocks = to_sub_blocks(base_model.down_blocks) + ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) + base_mid_subblocks = to_sub_blocks([base_model.mid_block]) + ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) + base_up_subblocks = to_sub_blocks(base_model.up_blocks) + + # Cross Control + # 0 - conv in + h_base = base_model.conv_in(h_base) + h_ctrl = self.control_model.conv_in(h_ctrl) + if guided_hint is not None: + h_ctrl += guided_hint + h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + # 1 - down + for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks): + h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock + h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + # 2 - mid + h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl + for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock + h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base + + # 3 - up + for i, m_base in enumerate(base_up_subblocks): + h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder + h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + + h_base = base_model.conv_norm_out(h_base) + h_base = base_model.conv_act(h_base) + h_base = base_model.conv_out(h_base) + + if not return_dict: + return h_base + + return ControlNetXSOutput(sample=h_base) + + def _make_zero_conv(self, in_channels, out_channels=None): + # keep running track of channels sizes + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + @torch.no_grad() + def _check_if_vae_compatible(self, vae: AutoencoderKL): + condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1) + vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + compatible = condition_downscale_factor == vae_downscale_factor + return compatible, condition_downscale_factor, vae_downscale_factor + + +class SubBlock(nn.ModuleList): + """A SubBlock is the largest piece of 
either base or control model, that is executed independently of the other model respectively. + Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base. + """ + + def __init__(self, ms, *args, **kwargs): + if not is_iterable(ms): + ms = [ms] + super().__init__(ms, *args, **kwargs) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor, + cemb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """Iterate through children and pass correct information to each.""" + for m in self: + if isinstance(m, ResnetBlock2D): + x = m(x, temb) + elif isinstance(m, Transformer2DModel): + x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample + elif isinstance(m, Downsample2D): + x = m(x) + elif isinstance(m, Upsample2D): + x = m(x) + else: + raise ValueError( + f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`" + ) + + return x + + +def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int): + unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim) + + +def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by): + """Increase channels sizes to allow for additional concatted information from base model""" + r = unet.down_blocks[block_no].resnets[resnet_idx] + old_norm1, old_conv1 = r.norm1, r.conv1 + # norm + norm_args = "num_groups num_channels eps affine".split(" ") + for a in norm_args: + assert hasattr(old_norm1, a) + norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} + norm_kwargs["num_channels"] += by # surgery done here + # conv1 + conv1_args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + conv1_args.append("lora_layer") + + for a in conv1_args: + assert hasattr(old_conv1, a) + + conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} + conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
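+    # only the input side of conv1 grows; out_channels stays as in the original resnet, which is why a 1x1
+    # conv_shortcut is built below to project the widened residual input back to out_channels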
+ conv1_kwargs["in_channels"] += by # surgery done here + # conv_shortcut + # as we changed the input size of the block, the input and output sizes are likely different, + # therefore we need a conv_shortcut (simply adding won't work) + conv_shortcut_args_kwargs = { + "in_channels": conv1_kwargs["in_channels"], + "out_channels": conv1_kwargs["out_channels"], + # default arguments from resnet.__init__ + "kernel_size": 1, + "stride": 1, + "padding": 0, + "bias": True, + } + # swap old with new modules + unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs) + unet.down_blocks[block_no].resnets[resnet_idx].conv1 = ( + nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) + ) + unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = ( + nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) + ) + unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here + + +def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by): + """Increase channels sizes to allow for additional concatted information from base model""" + old_down = unet.down_blocks[block_no].downsamplers[0].conv + + args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + args.append("lora_layer") + + for a in args: + assert hasattr(old_down, a) + kwargs = {a: getattr(old_down, a) for a in args} + kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor. + kwargs["in_channels"] += by # surgery done here + # swap old with new modules + unet.down_blocks[block_no].downsamplers[0].conv = ( + nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs) + ) + unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here + + +def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): + """Increase channels sizes to allow for additional concatted information from base model""" + m = unet.mid_block.resnets[0] + old_norm1, old_conv1 = m.norm1, m.conv1 + # norm + norm_args = "num_groups num_channels eps affine".split(" ") + for a in norm_args: + assert hasattr(old_norm1, a) + norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} + norm_kwargs["num_channels"] += by # surgery done here + conv1_args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + conv1_args.append("lora_layer") + + conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} + conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
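+    # same surgery as for the encoder resnets: widen only the first conv's input and add a 1x1 conv_shortcut
+    # so the residual addition still lines up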
+ conv1_kwargs["in_channels"] += by # surgery done here + # conv_shortcut + # as we changed the input size of the block, the input and output sizes are likely different, + # therefore we need a conv_shortcut (simply adding won't work) + conv_shortcut_args_kwargs = { + "in_channels": conv1_kwargs["in_channels"], + "out_channels": conv1_kwargs["out_channels"], + # default arguments from resnet.__init__ + "kernel_size": 1, + "stride": 1, + "padding": 0, + "bias": True, + } + # swap old with new modules + unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) + unet.mid_block.resnets[0].conv1 = ( + nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) + ) + unet.mid_block.resnets[0].conv_shortcut = ( + nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) + ) + unet.mid_block.resnets[0].in_channels += by # surgery done here + + +def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32): + def find_denominator(number, start): + if start >= number: + return number + while start != 0: + residual = number % start + if residual == 0: + return start + start -= 1 + + for block in [*unet.down_blocks, unet.mid_block]: + # resnets + for r in block.resnets: + if r.norm1.num_groups < max_num_group: + r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group) + + if r.norm2.num_groups < max_num_group: + r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group) + + # transformers + if hasattr(block, "attentions"): + for a in block.attentions: + if a.norm.num_groups < max_num_group: + a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group) + + +def is_iterable(o): + if isinstance(o, str): + return False + try: + iter(o) + return True + except TypeError: + return False + + +def to_sub_blocks(blocks): + if not is_iterable(blocks): + blocks = [blocks] + + sub_blocks = [] + + for b in blocks: + if hasattr(b, "resnets"): + if hasattr(b, "attentions") and b.attentions is not None: + for r, a in zip(b.resnets, b.attentions): + sub_blocks.append([r, a]) + + num_resnets = len(b.resnets) + num_attns = len(b.attentions) + + if num_resnets > num_attns: + # we can have more resnets than attentions, so add each resnet as separate subblock + for i in range(num_attns, num_resnets): + sub_blocks.append([b.resnets[i]]) + else: + for r in b.resnets: + sub_blocks.append([r]) + + # upsamplers are part of the same subblock + if hasattr(b, "upsamplers") and b.upsamplers is not None: + for u in b.upsamplers: + sub_blocks[-1].extend([u]) + + # downsamplers are own subblock + if hasattr(b, "downsamplers") and b.downsamplers is not None: + for d in b.downsamplers: + sub_blocks.append([d]) + + return list(map(SubBlock, sub_blocks)) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/modules/control/units/xs_pipe.py b/modules/control/units/xs_pipe.py new file mode 100644 index 000000000..51e6d8191 --- /dev/null +++ b/modules/control/units/xs_pipe.py @@ -0,0 +1,1938 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.import_utils import is_invisible_watermark_available +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from modules.control.units.xs_model import ControlNetXSModel + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLControlNetXSPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. 
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetXSModel`]: + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetXSModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( + vae + ) + if not vae_compatible: + raise ValueError( + f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. 
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
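+                        # e.g. clip_skip=1 selects hidden_states[-3], one layer further from the output than the default hidden_states[-2]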
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = 
negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
+ ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. 
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. 
+ callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is + returned, otherwise a `tuple` is returned containing the output images. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # predict the noise residual + dont_control = ( + i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + ) + if dont_control: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + ).sample + else: + noise_pred = self.controlnet( + base_model=self.unet, + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, 
t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + +class StableDiffusionControlNetXSPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetXSModel`]): + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. 
+
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+    """
+
+    model_cpu_offload_seq = "text_encoder->unet->vae->controlnet"
+    _optional_components = ["safety_checker", "feature_extractor"]
+    _exclude_from_cpu_offload = ["safety_checker"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        controlnet: ControlNetXSModel,
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        requires_safety_checker: bool = True,
+    ):
+        super().__init__()
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
+            vae
+        )
+        if not vae_compatible:
+            raise ValueError(
+                f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
+            )
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            controlnet=controlnet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+ """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
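+                # Note: `prompt_embeds[-1]` is the tuple of hidden states returned with
+                # `output_hidden_states=True`, so `clip_skip=1` picks `hidden_states[-2]`,
+                # the pre-final encoder layer, before the final LayerNorm is applied below.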
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + 
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. 
+ """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + if isinstance(controlnet, ControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + dont_control = ( + i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + ) + if dont_control: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + ).sample + else: + noise_pred = self.controlnet( + base_model=self.unet, + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/modules/control/util.py b/modules/control/util.py new file mode 100644 index 000000000..9fe42877e --- /dev/null +++ b/modules/control/util.py @@ -0,0 +1,160 @@ +import os +import sys +import random +import cv2 +import numpy as np +import torch + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def dict2str(d: dict): + arr = [f'{name}: {d[name]}' for i, name in 
enumerate(d)] + return ' | '.join(arr) + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + _H, _W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def make_noise_disk(H, W, C, F): + noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) # noqa + noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) + noise = noise[F: F + H, F: F + W] + noise -= np.min(noise) + noise /= np.max(noise) + if C == 1: + noise = noise[:, :, None] + return noise + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + y = np.zeros_like(x) + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + +def min_max_norm(x): + x -= np.min(x) + x /= np.maximum(np.max(x), 1e-5) + return x + + +def safe_step(x, step=2): + y = x.astype(np.float32) * float(step + 1) + y = y.astype(np.int32).astype(np.float32) / float(step) + return y + + +def img2mask(img, H, W, low=10, high=90): + assert img.ndim == 3 or img.ndim == 2 + assert img.dtype == np.uint8 + if img.ndim == 3: + y = img[:, :, random.randrange(0, img.shape[2])] + else: + y = img + y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) + if random.uniform(0, 1) < 0.5: + y = 255 - y + return y < np.percentile(y, random.randrange(low, high)) + + +def resize_image(input_image, resolution): + H, W, _C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img + + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def ade_palette(): + """ADE20K palette that maps each class to RGB values.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 
102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + +def blend(images): + if images is None or len(images) == 0: + return images + y = np.zeros(images[0].shape, dtype=np.float32) + for img in images: + y = cv2.add(y, img.astype(np.float32)) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def decode_fourcc(cc): + cc_bytes = int(cc).to_bytes(4, byteorder=sys.byteorder) # convert code to a bytearray + cc_str = cc_bytes.decode() # decode byteaarray to a string + return cc_str diff --git a/modules/dml/__init__.py b/modules/dml/__init__.py index 3b9c8cf63..358885650 100644 --- a/modules/dml/__init__.py +++ b/modules/dml/__init__.py @@ -85,7 +85,7 @@ class OverrideItem(NamedTuple): message: Optional[str] opts_override_table = { - "diffusers_generator_device": OverrideItem("cpu", None, "DirectML does not support torch Generator API"), + "diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"), "diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"), "diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"), } diff --git a/modules/errors.py b/modules/errors.py index 0f770856f..122628bff 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -4,9 +4,10 @@ from rich.theme import Theme from rich.pretty import install as pretty_install from rich.traceback import install as traceback_install -from installer import log as installer_log +from installer import log as installer_log, setup_logging +setup_logging() log = installer_log console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ "traceback.border": "black", diff --git a/modules/extensions.py b/modules/extensions.py index 4e37e4611..3ad20acec 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -153,3 +153,4 @@ def list_extensions(): for dirname, path, is_builtin in extension_paths: extension = Extension(name=dirname, path=path, enabled=dirname not in disabled_extensions, is_builtin=is_builtin) extensions.append(extension) + shared.log.info(f'Disabled extensions: {[e.name for e in extensions if not e.enabled]}') diff --git a/modules/face_restoration.py 
b/modules/face_restoration.py index 4ae53d21b..55e1033c6 100644 --- a/modules/face_restoration.py +++ b/modules/face_restoration.py @@ -13,7 +13,5 @@ def restore_faces(np_image): face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None] if len(face_restorers) == 0: return np_image - face_restorer = face_restorers[0] - return face_restorer.restore(np_image) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 8dd589772..9c5e50063 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -16,7 +16,8 @@ type_of_gr_update = type(gr.update()) paste_fields = {} registered_param_bindings = [] -debug = shared.log.info if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: PASTE') class ParamBinding: @@ -28,6 +29,7 @@ def __init__(self, paste_button, tabname, source_text_component=None, source_ima self.source_tabname = source_tabname self.override_settings_component = override_settings_component self.paste_field_names = paste_field_names or [] + debug(f'ParamBinding: {vars(self)}') def reset(): @@ -112,6 +114,8 @@ def create_buttons(tabs_list): name = 'Inpaint' elif name == 'extras': name = 'Process' + elif name == 'control': + name = 'Control' buttons[tab] = gr.Button(f"➠ {name}", elem_id=f"{tab}_tab") return buttons @@ -121,7 +125,8 @@ def bind_buttons(buttons, send_image, send_generate_info): for tabname, button in buttons.items(): source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None source_tabname = send_generate_info if isinstance(send_generate_info, str) else None - register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)) + bindings = ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname) + register_paste_params_button(bindings) def register_paste_params_button(binding: ParamBinding): @@ -131,6 +136,9 @@ def register_paste_params_button(binding: ParamBinding): def connect_paste_params_buttons(): binding: ParamBinding for binding in registered_param_bindings: + if binding.tabname not in paste_fields: + debug(f"Not not registered: tab={binding.tabname}") + continue destination_image_component = paste_fields[binding.tabname]["init_img"] fields = paste_fields[binding.tabname]["fields"] override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"] @@ -232,7 +240,8 @@ def parse_generation_parameters(x: str): res[k] = v except Exception: pass - res["Full quality"] = res.get('VAE', None) != 'TAESD' + if res.get('VAE', None) == 'TAESD': + res["Full quality"] = False debug(f"Parse prompt: {res}") return res diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 2d6acbafb..c8b6cd5c3 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -23,8 +23,8 @@ def train_hypernetwork(*args): Hypernetwork saved to {html.escape(filename)} """ return res, "" - except Exception: - raise + except Exception as e: + raise RuntimeError("Hypernetwork error") from e 
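# (Editor's note, illustrative only.) The `raise RuntimeError("Hypernetwork error") from e`
# change above uses explicit exception chaining: the original training failure stays
# attached as __cause__, so logs show both the wrapper and the real error instead of
# losing the cause. Minimal sketch with a hypothetical inner failure:
try:
    raise ValueError("hypothetical inner failure")
except Exception as e:
    raise RuntimeError("Hypernetwork error") from e  # traceback keeps the ValueError as the cause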
finally: shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) diff --git a/modules/images.py b/modules/images.py index 76c4a3cc6..87a6b43d0 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,17 +1,17 @@ -import datetime import io import re import os +import sys import math import json import uuid +import queue import string import hashlib -import queue +import datetime import threading from pathlib import Path from collections import namedtuple -import pytz import numpy as np import piexif import piexif.helper @@ -19,7 +19,7 @@ from modules import sd_samplers, shared, script_callbacks, errors, paths -debug = errors.log.info if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None try: from pi_heif import register_heif_opener register_heif_opener() @@ -210,12 +210,12 @@ def draw_prompt_matrix(im, width, height, all_prompts, margin=0): def resize_image(resize_mode, im, width, height, upscaler_name=None, output_type='image'): - # shared.log.debug(f'Image resize: mode={resize_mode} resolution={width}x{height} upscaler={upscaler_name}') + shared.log.debug(f'Image resize: mode={resize_mode} resolution={width}x{height} upscaler={upscaler_name} function={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access """ Resizes an image with the specified resize_mode, width, and height. Args: resize_mode: The mode to use when resizing the image. - 0: No resie + 0: No resize 1: Resize the image to the specified width and height. 2: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. 3: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. 
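# (Editor's illustrative sketch; plain logging is a stand-in for the project's shared.log /
# errors.log, which expose a custom trace level.) The SD_*_DEBUG pattern used in images.py
# and generation_parameters_copypaste.py above (and again in img2img.py later in this patch)
# swaps the debug function for a no-op lambda when the environment variable is unset, so
# trace logging costs nothing unless explicitly enabled:
import os
import sys
import logging

log = logging.getLogger("sdnext-sketch")
debug = log.debug if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None

def some_caller():
    # sys._getframe(1).f_code.co_name, as added to resize_image() above, records which
    # function requested the resize in the debug line
    debug(f'Image resize requested by={sys._getframe(1).f_code.co_name}')  # pylint: disable=protected-access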
@@ -242,9 +242,7 @@ def resize(im, w, h): im = im.resize((w, h), resample=Image.Resampling.LANCZOS) return im - if resize_mode == 0: - res = im.copy() - if width == 0 or height == 0: + if resize_mode == 0 or (im.width == width and im.height == height): res = im.copy() elif resize_mode == 1: res = resize(im, width, height) @@ -314,7 +312,7 @@ class FilenameGenerator: 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8], 'sampler': lambda self: self.p and self.p.sampler_name, - 'seed': lambda self: str(self.seed) if self.seed is not None else '', + 'seed': lambda self: self.seed and str(self.seed) or '', 'steps': lambda self: self.p and self.p.steps, 'styles': lambda self: self.p and ", ".join([style for style in self.p.styles if not style == "None"]) or "None", 'uuid': lambda self: str(uuid.uuid4()), @@ -322,8 +320,17 @@ class FilenameGenerator: default_time_format = '%Y%m%d%H%M%S' def __init__(self, p, seed, prompt, image, grid=False): + if p is None: + debug('Filename generator init skip') + else: + debug(f'Filename generator init: {seed} {prompt}') self.p = p - self.seed = seed + if seed is not None and seed > 0: + self.seed = seed + elif hasattr(p, 'all_seeds'): + self.seed = p.all_seeds[0] + else: + self.seed = 0 self.prompt = prompt self.image = image if not grid: @@ -335,7 +342,7 @@ def __init__(self, p, seed, prompt, image, grid=False): def hasprompt(self, *args): lower = self.prompt.lower() - if self.p is None or self.prompt is None: + if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None: return None outres = "" for arg in args: @@ -350,7 +357,7 @@ def hasprompt(self, *args): return outres def image_hash(self): - if self.image is None: + if getattr(self, 'image', None) is None: return None import base64 from io import BytesIO @@ -364,7 +371,7 @@ def prompt_full(self): return self.prompt_sanitize(self.prompt) def prompt_words(self): - if self.prompt is None: + if getattr(self, 'prompt', None) is None: return '' no_attention = re_attention.sub(r'\1', self.prompt) no_network = re_network.sub(r'\1', no_attention) @@ -374,7 +381,7 @@ def prompt_words(self): return self.prompt_sanitize(prompt) def prompt_no_style(self): - if self.p is None or self.prompt is None: + if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None: return None prompt_no_style = self.prompt for style in shared.prompt_styles.get_style_prompts(self.p.styles): @@ -385,6 +392,7 @@ def prompt_no_style(self): return self.prompt_sanitize(prompt_no_style) def datetime(self, *args): + import pytz time_datetime = datetime.datetime.now() time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format try: @@ -421,9 +429,11 @@ def sanitize(self, filename): if part in invalid_files: # reserved names [part := part.replace(word, '_') for word in invalid_files] # pylint: disable=expression-not-assigned newparts.append(part) - fn = Path(*newparts) - max_length = max(230, os.statvfs(__file__).f_namemax - 32 if hasattr(os, 'statvfs') else 230) - fn = str(fn)[:max_length-max(4, len(ext))].rstrip(invalid_suffix) + ext + fn = str(Path(*newparts)) + max_length = max(256 - len(ext), os.statvfs(__file__).f_namemax - 32 if hasattr(os, 'statvfs') else 256 - len(ext)) + while len(os.path.abspath(fn)) > max_length: + fn = fn[:-1] + fn += ext debug(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}') return fn @@ -455,9 +465,10 @@ def apply(self, x): break pattern, arg = m.groups() 
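# (Editor's illustrative sketch, simplified from the sanitize() change above; function and
# parameter names here are hypothetical.) The new length check trims the generated filename
# one character at a time until the absolute path fits the derived limit, then re-attaches
# the extension:
import os

def trim_filename(fn: str, ext: str, fallback: int = 256) -> str:
    limit = max(fallback - len(ext),
                os.statvfs(__file__).f_namemax - 32 if hasattr(os, 'statvfs') else fallback - len(ext))
    while len(os.path.abspath(fn)) > limit and len(fn) > 1:
        fn = fn[:-1]
    return fn + ext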
pattern_args.insert(0, arg) - fun = self.replacements.get(pattern.lower()) + fun = self.replacements.get(pattern.lower(), None) if fun is not None: try: + debug(f'Filename apply: pattern={pattern.lower()} args={pattern_args}') replacement = fun(self, *pattern_args) except Exception as e: replacement = None @@ -496,6 +507,8 @@ def atomically_save_image(): Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes while True: image, filename, extension, params, exifinfo, filename_txt = save_queue.get() + with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: + file.write(exifinfo) fn = filename + extension filename = filename.strip() if extension[0] != '.': # add dot if missing @@ -507,7 +520,16 @@ def atomically_save_image(): image_format = 'JPEG' if shared.opts.image_watermark_enabled: image = set_watermark(image, shared.opts.image_watermark) - shared.log.debug(f'Saving: image="{fn}" type={image_format} size={image.width}x{image.height}') + size = os.path.getsize(fn) if os.path.exists(fn) else 0 + shared.log.debug(f'Saving: image="{fn}" type={image_format} resolution={image.width}x{image.height} size={size}') + # additional metadata saved in files + if shared.opts.save_txt and len(exifinfo) > 0: + try: + with open(filename_txt, "w", encoding="utf8") as file: + file.write(f"{exifinfo}\n") + shared.log.debug(f'Saving: text="{filename_txt}" len={len(exifinfo)}') + except Exception as e: + shared.log.warning(f'Image description save failed: {filename_txt} {e}') # actual save exifinfo = (exifinfo or "") if shared.opts.image_metadata else "" if image_format == 'PNG': @@ -517,7 +539,7 @@ def atomically_save_image(): try: image.save(fn, format=image_format, compress_level=6, pnginfo=pnginfo_data if shared.opts.image_metadata else None) except Exception as e: - shared.log.warning(f'Image save failed: {fn} {e}') + shared.log.error(f'Image save failed: file="{fn}" {e}') elif image_format == 'JPEG': if image.mode == 'RGBA': shared.log.warning('Saving RGBA image as JPEG: Alpha channel will be lost') @@ -528,7 +550,7 @@ def atomically_save_image(): try: image.save(fn, format=image_format, optimize=True, quality=shared.opts.jpeg_quality, exif=exif_bytes) except Exception as e: - shared.log.warning(f'Image save failed: {fn} {e}') + shared.log.error(f'Image save failed: file="{fn}" {e}') elif image_format == 'WEBP': if image.mode == 'I;16': image = image.point(lambda p: p * 0.0038910505836576).convert("RGB") @@ -536,36 +558,25 @@ def atomically_save_image(): try: image.save(fn, format=image_format, quality=shared.opts.jpeg_quality, lossless=shared.opts.webp_lossless, exif=exif_bytes) except Exception as e: - shared.log.warning(f'Image save failed: {fn} {e}') + shared.log.error(f'Image save failed: file="{fn}" {e}') else: # shared.log.warning(f'Unrecognized image format: {extension} attempting save as {image_format}') try: image.save(fn, format=image_format, quality=shared.opts.jpeg_quality) except Exception as e: - shared.log.warning(f'Image save failed: {fn} {e}') - # additional metadata saved in files - if shared.opts.save_txt and len(exifinfo) > 0: - try: - with open(filename_txt, "w", encoding="utf8") as file: - file.write(f"{exifinfo}\n") - shared.log.debug(f'Saving: text="{filename_txt}"') - except Exception as e: - shared.log.warning(f'Image description save failed: {filename_txt} {e}') - with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: - file.write(exifinfo) + 
shared.log.error(f'Image save failed: file="{fn}" {e}') if shared.opts.save_log_fn != '' and len(exifinfo) > 0: fn = os.path.join(paths.data_path, shared.opts.save_log_fn) if not fn.endswith('.json'): fn += '.json' - entries = shared.readfile(fn) + entries = shared.readfile(fn, silent=True) idx = len(list(entries)) if idx == 0: entries = [] entry = { 'id': idx, 'filename': filename, 'time': datetime.datetime.now().isoformat(), 'info': exifinfo } entries.append(entry) - shared.writefile(entries, fn, mode='w') - with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: - file.write(exifinfo) + shared.writefile(entries, fn, mode='w', silent=True) + shared.log.debug(f'Saving: json="{fn}" records={len(entries)}') save_queue.task_done() @@ -575,6 +586,7 @@ def atomically_save_image(): def save_image(image, path, basename='', seed=None, prompt=None, extension=shared.opts.samples_format, info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix='', save_to_dirs=None): # pylint: disable=unused-argument + debug(f'Save from function={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access if image is None: shared.log.warning('Image is none') return None, None @@ -608,6 +620,7 @@ def save_image(image, path, basename='', seed=None, prompt=None, extension=share if dirname is not None and len(dirname) > 0: os.makedirs(dirname, exist_ok=True) params.filename = namegen.sequence(params.filename, dirname, basename) + params.filename = namegen.sanitize(params.filename) # callbacks script_callbacks.before_image_saved_callback(params) exifinfo = params.pnginfo.get('UserComment', '') @@ -646,7 +659,8 @@ def save_video_atomic(images, filename, video_type: str = 'none', duration: floa for i in range(len(video_frames)): img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) video_writer.write(img) - shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc}') + size = os.path.getsize(filename) + shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc} size={size}') if video_type.lower() == 'gif' or video_type.lower() == 'png': append = images.copy() image = append.pop(0) @@ -661,20 +675,33 @@ def save_video_atomic(images, filename, video_type: str = 'none', duration: floa duration = 1000.0 * duration / frames, loop = 0 if loop else 1, ) - shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop}') + size = os.path.getsize(filename) + shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop} size={size}') -def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3): +def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3, sync: bool = False): if images is None or len(images) < 2 or video_type is None or video_type.lower() == 'none': return image = images[0] - namegen = FilenameGenerator(p, seed=p.all_seeds[0], prompt=p.all_prompts[0], image=image) - if filename is None: + if p is not None: + namegen = FilenameGenerator(p, seed=p.all_seeds[0], prompt=p.all_prompts[0], image=image) + else: + namegen = FilenameGenerator(None, seed=0, 
prompt='', image=image) + if filename is None and p is not None: filename = namegen.apply(shared.opts.samples_filename_pattern if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0 else "[seq]-[prompt_words]") - filename = namegen.sanitize(os.path.join(shared.opts.outdir_video, filename)) + filename = os.path.join(shared.opts.outdir_video, filename) filename = namegen.sequence(filename, shared.opts.outdir_video, '') - filename = f'{filename}.{video_type.lower()}' - threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start() + else: + if os.pathsep not in filename: + filename = os.path.join(shared.opts.outdir_video, filename) + if not filename.lower().endswith(video_type.lower()): + filename += f'.{video_type.lower()}' + filename = namegen.sanitize(filename) + if not sync: + threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start() + else: + save_video_atomic(images, filename, video_type, duration, loop, interpolate, scale, pad, change) + return filename def safe_decode_string(s: bytes): diff --git a/modules/img2img.py b/modules/img2img.py index e0fd159b0..4c0ce910b 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -9,6 +9,10 @@ from modules.memstats import memory_stats +debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: PROCESS') + + def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args): shared.log.debug(f'batch: {input_files}|{input_dir}|{output_dir}|{inpaint_mask_dir}') processing.fix_seed(p) @@ -79,16 +83,16 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args) if proc is None: proc = processing.process_images(p) for n, (image, image_file) in enumerate(itertools.zip_longest(proc.images,batch_image_files)): - basename, ext = os.path.splitext(os.path.basename(image_file)) - ext = ext[1:] + basename = '' + if shared.opts.use_original_name_batch: + forced_filename, ext = os.path.splitext(os.path.basename(image_file)) + else: + forced_filename = None + ext = shared.opts.samples_format if len(proc.images) > 1: - if shared.opts.batch_frame_mode: # SBM Frames are numbered globally. - basename = f'{basename}-{n + i}' - else: # Images are numbered per rept. 
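# (Editor's illustrative sketch; `_encode_video` is a hypothetical stand-in for the
# save_video_atomic() worker above.) The new `sync` flag decides whether video encoding
# blocks the caller or runs on a background thread, and the resolved filename is returned
# either way:
import threading

def _encode_video(frames, filename):
    pass  # stand-in for the actual GIF/PNG/MP4 writer

def save_video_sketch(frames, filename, sync=False):
    if not sync:
        threading.Thread(target=_encode_video, args=(frames, filename)).start()
    else:
        _encode_video(frames, filename)
    return filename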
- basename = f'{basename}-{n}' - if not shared.opts.use_original_name_batch: + basename = f'{n + i}' if shared.opts.batch_frame_mode else f'{n}' + else: basename = '' - ext = shared.opts.samples_format if output_dir == '': output_dir = shared.opts.outdir_img2img_samples if not save_normally: @@ -96,17 +100,47 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args) geninfo, items = images.read_info_from_image(image) for k, v in items.items(): image.info[k] = v - images.save_image(image, path=output_dir, basename=basename, seed=None, prompt=None, extension=ext, info=geninfo, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=image.info, forced_filename=None) + images.save_image(image, path=output_dir, basename=basename, seed=None, prompt=None, extension=ext, info=geninfo, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=image.info, forced_filename=forced_filename) shared.log.debug(f'Processed: images={len(batch_image_files)} memory={memory_stats()} batch') -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, latent_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, full_quality: bool, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, diffusers_guidance_rescale: float, refiner_steps: int, refiner_start: float, clip_skip: int, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_files: list, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): # pylint: disable=unused-argument +def img2img(id_task: str, mode: int, + prompt, negative_prompt, prompt_styles, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + inpaint_color_sketch_orig, + init_img_inpaint, + init_mask_inpaint, + steps, + sampler_index, latent_index, + mask_blur, mask_alpha, + inpainting_fill, + full_quality, restore_faces, tiling, + n_iter, batch_size, + cfg_scale, image_cfg_scale, + diffusers_guidance_rescale, + refiner_steps, + refiner_start, + clip_skip, + denoising_strength, + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, + selected_scale_tab, + height, width, + scale_by, + resize_mode, resize_name, + inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, + img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, + hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, + override_settings_texts, + *args): # pylint: disable=unused-argument if shared.sd_model is None: shared.log.warning('Model not loaded') return [], '', '', 'Error: model not loaded' - shared.log.debug(f'img2img: 
id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}|latent_index={latent_index}|mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed{subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}') + debug(f'img2img: id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}|latent_index={latent_index}|mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed{subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|resize_name={resize_name}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}') if mode == 5: if img2img_batch_files is None or len(img2img_batch_files) == 0: @@ -156,6 +190,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s image = init_img_inpaint mask = init_mask_inpaint else: + shared.log.error(f'Image processing unknown mode: {mode}') image = None mask = None if image is not None: @@ -194,14 +229,18 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s mask_blur=mask_blur, inpainting_fill=inpainting_fill, resize_mode=resize_mode, + resize_name=resize_name, denoising_strength=denoising_strength, image_cfg_scale=image_cfg_scale, diffusers_guidance_rescale=diffusers_guidance_rescale, refiner_steps=refiner_steps, 
refiner_start=refiner_start, - inpaint_full_res=inpaint_full_res, + inpaint_full_res=inpaint_full_res != 0, inpaint_full_res_padding=inpaint_full_res_padding, inpainting_mask_invert=inpainting_mask_invert, + hdr_clamp=hdr_clamp, hdr_boundary=hdr_boundary, hdr_threshold=hdr_threshold, + hdr_center=hdr_center, hdr_channel_shift=hdr_channel_shift, hdr_full_shift=hdr_full_shift, + hdr_maximize=hdr_maximize, hdr_max_center=hdr_max_center, hdr_max_boundry=hdr_max_boundry, override_settings=override_settings, ) if selected_scale_tab == 1 and resize_mode != 0: diff --git a/modules/intel/ipex/__init__.py b/modules/intel/ipex/__init__.py index cbadd14fc..c78547915 100644 --- a/modules/intel/ipex/__init__.py +++ b/modules/intel/ipex/__init__.py @@ -4,14 +4,12 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from .hijacks import ipex_hijacks -from .attention import attention_init -from .diffusers import ipex_diffusers # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements try: - #Replace cuda with xpu: + # Replace cuda with xpu: torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.device = torch.xpu.device @@ -92,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.__file__ = torch.xpu.__file__ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - #Memory: + # Memory: torch.cuda.memory = torch.xpu.memory if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None @@ -114,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - #RNG: + # RNG: torch.cuda.get_rng_state = torch.xpu.get_rng_state torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all torch.cuda.set_rng_state = torch.xpu.set_rng_state @@ -125,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.seed_all = torch.xpu.seed_all torch.cuda.initial_seed = torch.xpu.initial_seed - #AMP: + # AMP: torch.cuda.amp = torch.xpu.amp if not hasattr(torch.cuda.amp, "common"): torch.cuda.amp.common = contextlib.nullcontext() @@ -140,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements except Exception: # pylint: disable=broad-exception-caught torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - #C + # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream ipex._C._DeviceProperties.major = 2023 ipex._C._DeviceProperties.minor = 2 - #Fix functions with ipex: + # Fix functions with ipex: torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True @@ -158,16 +156,14 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_properties.minor = 7 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - if hasattr(torch.xpu, 'getDeviceIdListForCard'): - 
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard - else: - torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card - torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() - attention_init() - ipex_diffusers() + if not torch.xpu.has_fp64_dtype(): + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass except Exception as e: return False, e return True, None diff --git a/modules/intel/ipex/attention.py b/modules/intel/ipex/attention.py index 094ea5104..ce795771b 100644 --- a/modules/intel/ipex/attention.py +++ b/modules/intel/ipex/attention.py @@ -1,45 +1,90 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long -original_torch_bmm = torch.bmm -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) - - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] - block_multiply = input.element_size() - slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers + +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 6)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +# Find something divisible with the input_tokens +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +# Find slice sizes for SDPA +@cache +def find_sdpa_slice_sizes(query_shape, query_element_size): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size block_size = batch_size_attention * slice_block_size split_slice_size = batch_size_attention - if block_size > 4: + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > sdpa_slice_trigger_rate: do_split = True - #Find something divisible with the input_tokens - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_block_size_2) + if split_2_slice_size * slice_block_size_2 > attention_slice_rate: + slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_block_size_3) + + return do_split, do_split_2, 
do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size +# Find slice sizes for BMM +@cache +def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): + batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] + slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention split_2_slice_size = input_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply - do_split_2 = True - #Find something divisible with the input_tokens - while (split_2_slice_size * slice_block_size2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False + do_split = False + do_split_2 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_block_size_2 = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_block_size_2) + + return do_split, do_split_2, split_slice_size, split_2_slice_size + + +original_torch_bmm = torch.bmm +def torch_bmm_32_bit(input, mat2, *, out=None): + do_split, do_split_2, split_slice_size, split_2_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) + + # Slice BMM if do_split: + batch_size_attention, input_tokens = input.shape[0], input.shape[1] hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -64,51 +109,12 @@ def torch_bmm(input, mat2, *, out=None): return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - if len(query.shape) == 3: - batch_size_attention, query_tokens, shape_four = query.shape - shape_one = 1 - no_shape_one = True - else: - shape_one, batch_size_attention, query_tokens, shape_four = query.shape - no_shape_one = False - - if query.dtype != key.dtype: - key = key.to(dtype=query.dtype) - if query.dtype != value.dtype: - value = value.to(dtype=query.dtype) - - block_multiply = query.element_size() - slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - if block_size > 6: - do_split = True - #Find something divisible with the shape_one - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False - - split_2_slice_size = query_tokens - if split_slice_size * slice_block_size > 6: - slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply - do_split_2 = True - #Find something divisible with the batch_size_attention - while (split_2_slice_size * slice_block_size2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - 
else: - do_split_2 = False +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + # Slice SDPA if do_split: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -117,7 +123,18 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - if no_shape_one: + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( query[start_idx:end_idx, start_idx_2:end_idx_2], key[start_idx:end_idx, start_idx_2:end_idx_2], @@ -125,38 +142,16 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. 
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal ) - else: - hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx, start_idx_2:end_idx_2], - key[:, start_idx:end_idx, start_idx_2:end_idx_2], - value[:, start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) else: - if no_shape_one: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx], - key[:, start_idx:end_idx], - value[:, start_idx:end_idx], - attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) + hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) else: return original_scaled_dot_product_attention( query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal ) return hidden_states - -def attention_init(): - #ARC GPUs can't allocate more than 4GB to a single block: - torch.bmm = torch_bmm - torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/modules/intel/ipex/diffusers.py b/modules/intel/ipex/diffusers.py index 005ee49f0..6bc6aae31 100644 --- a/modules/intel/ipex/diffusers.py +++ b/modules/intel/ipex/diffusers.py @@ -1,10 +1,24 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.21.1 # pylint: disable=import-error +import diffusers #0.24.0 # pylint: disable=import-error from diffusers.models.attention_processor import Attention +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + + class SlicedAttnProcessor: # pylint: disable=too-few-public-methods r""" Processor for implementing sliced attention. 
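# (Editor's illustrative sketch, not the shipped code; simplified from the find_slice_size /
# find_bmm_slice_sizes helpers above in attention.py and diffusers.py.) The idea: Alchemist/ARC
# GPUs cannot allocate more than 4GB in a single block, so the batch dimension is halved until
# the estimated memory of one slice drops under the IPEX_ATTENTION_SLICE_RATE budget (in MB),
# and the matmul is then run slice by slice:
import torch

ATTENTION_SLICE_RATE_MB = 4.0  # mirrors the IPEX_ATTENTION_SLICE_RATE default above

def find_slice_size(slice_size: int, slice_block_size_mb: float) -> int:
    while slice_size * slice_block_size_mb > ATTENTION_SLICE_RATE_MB:
        slice_size //= 2
        if slice_size <= 1:
            return 1
    return slice_size

def sliced_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # estimated MB produced per batch element of the a @ b result
    per_item_mb = a.shape[1] * b.shape[2] / 1024 / 1024 * a.element_size()
    step = find_slice_size(a.shape[0], per_item_mb)
    if step == a.shape[0]:
        return torch.bmm(a, b)
    out = torch.empty(a.shape[0], a.shape[1], b.shape[2], device=a.device, dtype=a.dtype)
    for start in range(0, a.shape[0], step):
        out[start:start + step] = torch.bmm(a[start:start + step], b[start:start + step])
    return out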
@@ -61,12 +75,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a split_2_slice_size = query_tokens if block_size > 4: do_split_2 = True - #Find something divisible with the query_tokens - while (split_2_slice_size * slice_block_size) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break + split_2_slice_size = find_slice_size(split_2_slice_size, slice_block_size) else: do_split_2 = False diff --git a/modules/intel/ipex/gradscaler.py b/modules/intel/ipex/gradscaler.py index 530212101..6eb56bc2b 100644 --- a/modules/intel/ipex/gradscaler.py +++ b/modules/intel/ipex/gradscaler.py @@ -5,6 +5,7 @@ # pylint: disable=protected-access, missing-function-docstring, line-too-long +device_supports_fp64 = torch.xpu.has_fp64_dtype() OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state @@ -96,7 +97,10 @@ def unscale_(self, optimizer): # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. assert self._scale is not None - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + if device_supports_fp64: + inv_scale = self._scale.double().reciprocal().float() + else: + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) found_inf = torch.full( (1,), 0.0, dtype=torch.float32, device=self._scale.device ) diff --git a/modules/intel/ipex/hijacks.py b/modules/intel/ipex/hijacks.py index cdd83cea8..4573b7f7f 100644 --- a/modules/intel/ipex/hijacks.py +++ b/modules/intel/ipex/hijacks.py @@ -65,7 +65,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) -#Embedding BF16 +# Embedding BF16 original_torch_cat = torch.cat def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): @@ -73,7 +73,7 @@ def torch_cat(tensor, *args, **kwargs): else: return original_torch_cat(tensor, *args, **kwargs) -#Latent antialias: +# Latent antialias: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -93,6 +93,32 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name else: return original_linalg_solve(A, B, *args, **kwargs) +if torch.xpu.has_fp64_dtype(): + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +else: + # 64 bit attention workarounds for Alchemist: + try: + from .attention import torch_bmm_32_bit as original_torch_bmm + from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + except Exception: # pylint: disable=broad-exception-caught + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + +# dtype errors: +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + if query.dtype != key.dtype: + key = 
key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + +@property def is_cuda(self): return self.device.type == 'xpu' @@ -131,12 +157,16 @@ def ipex_hijacks(): lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) + if hasattr(torch.xpu, "Generator"): + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + else: + CondFunc('torch.Generator', + lambda orig_func, device=None: orig_func(return_xpu(device)), + lambda orig_func, device=None: check_device(device)) - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - - #TiledVAE and ControlNet: + # TiledVAE and ControlNet: CondFunc('torch.batch_norm', lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, weight if weight is not None else torch.ones(input.size()[1], device=input.device), @@ -148,47 +178,51 @@ def ipex_hijacks(): bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - #Functions with dtype errors: - #Original backend: + # Functions with dtype errors: CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - #Hypernetwork training: + # Training: CondFunc('torch.nn.modules.linear.Linear.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - #BF16: + # BF16: CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) - #SwinIR BF16: + # SwinIR BF16: CondFunc('torch.nn.functional.pad', lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) - #Diffusers Float64 (ARC GPUs doesn't support double or Float64): + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): if not torch.xpu.has_fp64_dtype(): CondFunc('torch.from_numpy', lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), lambda orig_func, ndarray: ndarray.dtype == float) - #Broken functions when 
torch.cuda.is_available is True: - #Pin Memory: + # Broken functions when torch.cuda.is_available is True: + # Pin Memory: CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), lambda orig_func, *args, **kwargs: True) - #Functions that make compile mad with CondFunc: - torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + # Functions that make compile mad with CondFunc: torch.nn.DataParallel = DummyDataParallel + torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + torch.autocast = ipex_autocast - torch.cat = torch_cat - torch.linalg.solve = linalg_solve + torch.backends.cuda.sdp_kernel = return_null_context torch.UntypedStorage.is_cuda = is_cuda + torch.nn.functional.interpolate = interpolate - torch.backends.cuda.sdp_kernel = return_null_context + torch.linalg.solve = linalg_solve + + torch.bmm = torch_bmm + torch.cat = torch_cat + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/modules/intel/openvino/__init__.py b/modules/intel/openvino/__init__.py index 500bfd73b..0c61a9757 100644 --- a/modules/intel/openvino/__init__.py +++ b/modules/intel/openvino/__init__.py @@ -1,20 +1,35 @@ import os import sys import torch +import nncf + from openvino.frontend import FrontEndManager from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder from openvino.frontend.pytorch.torchdynamo.partition import Partitioner from openvino.runtime import Core, Type, PartialShape, serialize + from torch._dynamo.backends.common import fake_tensor_unsupported from torch._dynamo.backends.registry import register_backend from torch.fx.experimental.proxy_tensor import make_fx from torch.fx import GraphModule from torch.utils._pytree import tree_flatten + from types import MappingProxyType from hashlib import sha256 import functools + from modules import shared, devices +NNCFNodeName = str +def get_node_by_name(self, name: NNCFNodeName) -> nncf.common.graph.NNCFNode: + node_ids = self._node_name_to_node_id_map.get(name, None) + if node_ids is None: + raise RuntimeError("Could not find a node {} in NNCFGraph!".format(name)) + + node_key = f"{node_ids[0]} {name}" + return self._nodes[node_key] +nncf.common.graph.NNCFGraph.get_node_by_name = get_node_by_name + def BUILD_MAP_UNPACK(self, inst): items = self.popn(inst.argval) # ensure everything is a dict @@ -50,18 +65,9 @@ def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: st self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache, "model_hash_str": model_hash_str} self.file_name = file_name - self.perm_fallback = False def __call__(self, *args): - #if self.perm_fallback: - # return self.gm(*args) - - #try: result = openvino_execute(self.gm, *args, executor_parameters=self.executor_parameters, partition_id=self.partition_id, file_name=self.file_name) - #except Exception: - # self.perm_fallback = True - # return self.gm(*args) - return result def get_device(): @@ -102,12 +108,6 @@ def get_openvino_device(): except Exception: return f"OpenVINO {get_device()}" -def cache_root_path(): - cache_root = "./cache/" - if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None: - cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR") - return cache_root - def cached_model_name(model_hash_str, device, args, cache_root, reversed = False): if model_hash_str is None: return None @@ -181,7 +181,7 @@ def 
openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na core = Core() device = get_device() - cache_root = cache_root_path() + cache_root = shared.opts.openvino_cache_path if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"): om = core.read_model(file_name + ".xml") @@ -229,12 +229,14 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype]) om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape))) om.validate_nodes_and_infer_types() + if shared.opts.nncf_compress_weights: + om = nncf.compress_weights(om) if model_hash_str is not None: core.set_property({'CACHE_DIR': cache_root + '/blob'}) - compiled = core.compile_model(om, device) - return compiled + compiled_model = core.compile_model(om, device) + return compiled_model def openvino_compile_cached_model(cached_model_path, *example_inputs): core = Core() @@ -255,8 +257,10 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs): om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype]) om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape))) om.validate_nodes_and_infer_types() + if shared.opts.nncf_compress_weights: + om = nncf.compress_weights(om) - core.set_property({'CACHE_DIR': cache_root_path() + '/blob'}) + core.set_property({'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'}) compiled_model = core.compile_model(om, get_device()) @@ -351,41 +355,44 @@ def openvino_fx(subgraph, example_inputs): # Check if the model was fully supported and already cached example_inputs.reverse() inputs_reversed = True - maybe_fs_cached_name = cached_model_name(model_hash_str + "_fs", get_device(), example_inputs, cache_root_path()) + maybe_fs_cached_name = cached_model_name(model_hash_str + "_fs", get_device(), example_inputs, shared.opts.openvino_cache_path) if os.path.isfile(maybe_fs_cached_name + ".xml") and os.path.isfile(maybe_fs_cached_name + ".bin"): - if (shared.compiled_model_state.cn_model != [] and str(shared.compiled_model_state.cn_model) in maybe_fs_cached_name): - example_inputs_reordered = [] - if (os.path.isfile(maybe_fs_cached_name + ".txt")): - f = open(maybe_fs_cached_name + ".txt", "r") - for input_data in example_inputs: - shape = f.readline() - if (str(input_data.size()) != shape): - for idx1, input_data1 in enumerate(example_inputs): - if (str(input_data1.size()).strip() == str(shape).strip()): - example_inputs_reordered.append(example_inputs[idx1]) - example_inputs = example_inputs_reordered - - # Model is fully supported and already cached. Run the cached OV model directly. 
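# (Editor's illustrative sketch; `om` stands for the OpenVINO model built in openvino_compile()
# above, and the flags mirror the new shared.opts.nncf_compress_weights / openvino_cache_path
# settings.) The change applies optional NNCF weight-only compression right before compilation,
# with the compiled blob cached under the configurable cache path:
import nncf
from openvino.runtime import Core

def compile_openvino_model(om, device: str, cache_root: str, compress_weights: bool = False):
    core = Core()
    if compress_weights:
        om = nncf.compress_weights(om)  # weight-only compression to shrink the model in memory
    core.set_property({'CACHE_DIR': cache_root + '/blob'})
    return core.compile_model(om, device)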
- compiled_model = openvino_compile_cached_model(maybe_fs_cached_name, *example_inputs) - - def _call(*args): - if (shared.compiled_model_state.cn_model != [] and str(shared.compiled_model_state.cn_model) in maybe_fs_cached_name): - args_reordered = [] - if (os.path.isfile(maybe_fs_cached_name + ".txt")): - f = open(maybe_fs_cached_name + ".txt", "r") - for input_data in args: - shape = f.readline() - if (str(input_data.size()) != shape): - for idx1, input_data1 in enumerate(args): - if (str(input_data1.size()).strip() == str(shape).strip()): - args_reordered.append(args[idx1]) - args = args_reordered - - res = execute_cached(compiled_model, *args) - shared.compiled_model_state.partition_id = shared.compiled_model_state.partition_id + 1 - return res - return _call + example_inputs_reordered = [] + if (os.path.isfile(maybe_fs_cached_name + ".txt")): + f = open(maybe_fs_cached_name + ".txt", "r") + for input_data in example_inputs: + shape = f.readline() + if (str(input_data.size()) != shape): + for idx1, input_data1 in enumerate(example_inputs): + if (str(input_data1.size()).strip() == str(shape).strip()): + example_inputs_reordered.append(example_inputs[idx1]) + example_inputs = example_inputs_reordered + + # Deleting unused subgraphs doesn't do anything, so we cast it down to fp8 + subgraph = subgraph.to(dtype=torch.float8_e4m3fn) + devices.torch_gc(force=True) + + # Model is fully supported and already cached. Run the cached OV model directly. + compiled_model = openvino_compile_cached_model(maybe_fs_cached_name, *example_inputs) + + def _call(*args): + if (shared.compiled_model_state.cn_model != [] and str(shared.compiled_model_state.cn_model) in maybe_fs_cached_name): + args_reordered = [] + if (os.path.isfile(maybe_fs_cached_name + ".txt")): + f = open(maybe_fs_cached_name + ".txt", "r") + for input_data in args: + shape = f.readline() + if (str(input_data.size()) != shape): + for idx1, input_data1 in enumerate(args): + if (str(input_data1.size()).strip() == str(shape).strip()): + args_reordered.append(args[idx1]) + args = args_reordered + + res = execute_cached(compiled_model, *args) + shared.compiled_model_state.partition_id = shared.compiled_model_state.partition_id + 1 + return res + return _call else: os.environ.setdefault('OPENVINO_TORCH_MODEL_CACHING', "0") maybe_fs_cached_name = None diff --git a/modules/k-diffusion b/modules/k-diffusion index 045515774..cc49cf618 160000 --- a/modules/k-diffusion +++ b/modules/k-diffusion @@ -1 +1 @@ -Subproject commit 045515774882014cc14c1ba2668ab5bad9cbf7c0 +Subproject commit cc49cf6182284e577e896943f8e29c7c9d1a7f2c diff --git a/modules/loader.py b/modules/loader.py index 8b06f09ac..31cadcd56 100644 --- a/modules/loader.py +++ b/modules/loader.py @@ -1,11 +1,13 @@ from __future__ import annotations import re +import sys import logging import warnings import urllib3 from modules import timer, errors initialized = False +errors.install() logging.getLogger("DeepSpeed").disabled = True # os.environ.setdefault('OMP_NUM_THREADS', 1) # os.environ.setdefault('MKL_NUM_THREADS', 1) @@ -55,3 +57,12 @@ errors.log.debug(f'Detected: cores={cores} affinity={affinity} set threads={threads}') except Exception: pass + +try: # fix changed import in torchvision 0.17+, which breaks basicsr + import torchvision.transforms.functional_tensor # pylint: disable=unused-import, ungrouped-imports +except ImportError: + try: + import torchvision.transforms.functional as functional + sys.modules["torchvision.transforms.functional_tensor"] = functional + except 
ImportError: + pass # shrug... diff --git a/modules/lora b/modules/lora index 0908c5414..1a36f9dc6 160000 --- a/modules/lora +++ b/modules/lora @@ -1 +1 @@ -Subproject commit 0908c5414dadfe2185669e2d657d542161647b79 +Subproject commit 1a36f9dc65fb3be8baa9dcbde832048cc5644efa diff --git a/modules/lowvram.py b/modules/lowvram.py index cb684acd2..5afb7e7c8 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -52,7 +52,7 @@ def first_stage_model_decode_wrap(z): return first_stage_model_decode(z) # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field - if hasattr(sd_model.cond_stage_model, 'model'): + if hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model'): sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then @@ -73,7 +73,7 @@ def first_stage_model_decode_wrap(z): sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model - if hasattr(sd_model.cond_stage_model, 'model'): + if hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model'): sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer del sd_model.cond_stage_model.transformer diff --git a/modules/modelloader.py b/modules/modelloader.py index 0e8254024..9cd05dc82 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse import PIL.Image as Image import rich.progress as p -from modules import shared +from modules import shared, errors from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -68,6 +68,7 @@ def download_civit_meta(model_path: str, model_id): return msg except Exception as e: msg = f'CivitAI download error: id={model_id} url={url} file={fn} {e}' + errors.display(e, 'CivitAI download error') shared.log.error(msg) return msg return f'CivitAI download error: id={model_id} url={url} code={r.status_code}' @@ -100,7 +101,7 @@ def download_civit_preview(model_path: str, preview_url: str): except Exception as e: os.remove(preview_file) res += f' error={e}' - shared.log.error(f'CivitAI download error: url={preview_url} file={preview_file} {e}') + shared.log.error(f'CivitAI download error: url={preview_url} file={preview_file} written={written} {e}') shared.state.end() if img is None: return res @@ -111,7 +112,7 @@ def download_civit_preview(model_path: str, preview_url: str): download_pbar = None -def download_civit_model_thread(model_name, model_url, model_path, model_type, preview): +def download_civit_model_thread(model_name, model_url, model_path, model_type, preview, token): import hashlib sha256 = hashlib.sha256() sha256.update(model_name.encode('utf-8')) @@ -135,7 +136,9 @@ def download_civit_model_thread(model_name, model_url, model_path, model_type, p if os.path.isfile(temp_file): starting_pos = os.path.getsize(temp_file) res += f' resume={round(starting_pos/1024/1024)}Mb' - headers = {'Range': f'bytes={starting_pos}-'} + headers['Range'] = f'bytes={starting_pos}-' + if token is not None and len(token) > 0: + headers['Authorization'] = f'Bearer {token}' r = shared.req(model_url, headers=headers, stream=True) total_size = int(r.headers.get('content-length', 0)) @@ -175,9 +178,9 @@ def download_civit_model_thread(model_name, 
model_url, model_path, model_type, p return res -def download_civit_model(model_url: str, model_name: str, model_path: str, model_type: str, preview): +def download_civit_model(model_url: str, model_name: str, model_path: str, model_type: str, preview, token: str = None): import threading - thread = threading.Thread(target=download_civit_model_thread, args=(model_name, model_url, model_path, model_type, preview)) + thread = threading.Thread(target=download_civit_model_thread, args=(model_name, model_url, model_path, model_type, preview, token)) thread.start() return f'CivitAI download: name={model_name} url={model_url} path={model_path}' @@ -232,8 +235,7 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config shared.log.error(f"Diffusers download error: {hub_id} {err}") return None try: - # TODO diffusers is this real error? - model_info_dict = hf.model_info(hub_id).cardData if pipeline_dir is not None else None # pylint: disable=no-member + model_info_dict = hf.model_info(hub_id).cardData if pipeline_dir is not None else None except Exception: model_info_dict = None if model_info_dict is not None and "prior" in model_info_dict: # some checkpoints need to be downloaded as "hidden" as they just serve as pre- or post-pipelines of other pipelines @@ -290,8 +292,9 @@ def load_diffusers_models(model_path: str, command_path: str = None, clear=True) if os.path.exists(os.path.join(folder, 'hidden')): continue output.append(name) - except Exception as e: - shared.log.error(f"Error analyzing diffusers model: {folder} {e}") + except Exception: + # shared.log.error(f"Error analyzing diffusers model: {folder} {e}") + pass except Exception as e: shared.log.error(f"Error listing diffusers: {place} {e}") shared.log.debug(f'Scanning diffusers cache: {model_path} {command_path} items={len(output)} time={time.time()-t0:.2f}') @@ -335,6 +338,27 @@ def load_reference(name: str): return True +def load_civitai(model: str, url: str): + from modules import sd_models + name, _ext = os.path.splitext(model) + info = sd_models.get_closet_checkpoint_match(name) + if info is not None: + shared.log.debug(f'Reference model: {name}') + return name # already downloaded + else: + shared.log.debug(f'Reference model: {name} download start') + download_civit_model_thread(model_name=model, model_url=url, model_path='', model_type='safetensors', preview=None, token=None) + shared.log.debug(f'Reference model: {name} download complete') + sd_models.list_models() + info = sd_models.get_closet_checkpoint_match(name) + if info is not None: + shared.log.debug(f'Reference model: {name}') + return name # already downloaded + else: + shared.log.debug(f'Reference model: {name} not found') + return None + + cache_folders = {} cache_last = 0 cache_time = 1 @@ -605,3 +629,4 @@ def load_upscalers(): shared.sd_upscalers = sorted(datas, key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "") # Special case for UpscalerNone keeps it at the beginning of the list. 
t1 = time.time() shared.log.debug(f"Load upscalers: total={len(shared.sd_upscalers)} downloaded={len([x for x in shared.sd_upscalers if x.data_path is not None and os.path.isfile(x.data_path)])} user={len([x for x in shared.sd_upscalers if x.custom])} time={t1-t0:.2f} {names}") + return [x.name for x in shared.sd_upscalers] diff --git a/modules/patches.py b/modules/patches.py index a6c3a25bb..cff6bfd64 100644 --- a/modules/patches.py +++ b/modules/patches.py @@ -50,6 +50,7 @@ def undo(key, obj, field): original_func = originals[key].pop(patch_key) setattr(obj, field, original_func) + return None def original(key, obj, field): diff --git a/modules/paths.py b/modules/paths.py index d7a9e8691..9347c51a0 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -32,7 +32,8 @@ sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") sd_model_file = cli.ckpt or os.path.join(script_path, 'model.ckpt') # not used default_sd_model_file = sd_model_file # not used -debug = log.info if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: PATH') paths = {} if os.environ.get('SD_PATH_DEBUG', None) is not None: diff --git a/modules/postprocess/sdupscaler_model.py b/modules/postprocess/sdupscaler_model.py index f4c069925..0d73b8fa9 100644 --- a/modules/postprocess/sdupscaler_model.py +++ b/modules/postprocess/sdupscaler_model.py @@ -45,7 +45,7 @@ def do_upscale(self, img: Image.Image, selected_model): if model is None: return img seeds = [torch.randint(0, 2 ** 32, (1,)).item() for _ in range(1)] - generator_device = devices.cpu if shared.opts.diffusers_generator_device == "cpu" else devices.device + generator_device = devices.cpu if shared.opts.diffusers_generator_device == "CPU" else devices.device generator = [torch.Generator(generator_device).manual_seed(s) for s in seeds] args = { 'prompt': '', diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 82de57578..150c641dc 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -13,24 +13,29 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp shared.state.begin('extras') image_data = [] image_names = [] + image_fullnames = [] image_ext = [] outputs = [] params = {} if extras_mode == 1: - shared.log.debug(f'process: mode=batch folder={image_folder}') for img in image_folder: if isinstance(img, Image.Image): image = img fn = '' ext = None else: - image = Image.open(os.path.abspath(img.name)) + try: + image = Image.open(os.path.abspath(img.name)) + except Exception as e: + shared.log.error(f'Failed to open image: file="{img.name}" {e}') + continue fn, ext = os.path.splitext(img.orig_name) + image_fullnames.append(img.name) image_data.append(image) image_names.append(fn) image_ext.append(ext) + shared.log.debug(f'Process: mode=batch inputs={len(image_folder)} images={len(image_data)}') elif extras_mode == 2: - shared.log.debug(f'process: mode=folder folder={input_dir}') assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' assert input_dir, 'input directory not selected' image_list = shared.listfiles(input_dir) @@ -38,11 +43,13 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp try: image = Image.open(filename) except Exception as e: - shared.log.error(f'Failed to open image: {filename} {e}') + shared.log.error(f'Failed to open image: file="{filename}" {e}') 
continue + image_fullnames.append(filename) image_data.append(image) image_names.append(filename) image_ext.append(None) + shared.log.debug(f'Process: mode=folder inputs={input_dir} files={len(image_list)} images={len(image_data)}') else: image_data.append(image) image_names.append(None) @@ -51,8 +58,9 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp outpath = output_dir else: outpath = opts.outdir_samples or opts.outdir_extras_samples + processed_images = [] for image, name, ext in zip(image_data, image_names, image_ext): # pylint: disable=redefined-argument-from-local - shared.log.debug(f'process: image={image} {args}') + shared.log.debug(f'Process: image={image} {args}') infotext = '' if shared.state.interrupted: shared.log.debug('Postprocess interrupted') @@ -74,11 +82,13 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp infotext = items['parameters'] + ', ' infotext = infotext + ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) pp.image.info["postprocessing"] = infotext + processed_images.append(pp.image) if save_output: images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=ext or opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=pp.image.info, forced_filename=None) if extras_mode != 2 or show_extras_results: outputs.append(pp.image) image.close() + scripts.scripts_postproc.postprocess(processed_images, args) devices.torch_gc() return outputs, infotext, params diff --git a/modules/processing.py b/modules/processing.py index b04388d8f..e749eb59c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -11,7 +11,7 @@ import torch import numpy as np import cv2 -from PIL import Image, ImageFilter, ImageOps +from PIL import Image, ImageOps from skimage import exposure from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion @@ -44,6 +44,9 @@ opt_C = 4 opt_f = 8 +debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: PROCESS') + def setup_color_correction(image): shared.log.debug("Calibrating color correction.") @@ -64,10 +67,11 @@ def apply_overlay(image: Image, paste_loc, index, overlays): overlay = overlays[index] if paste_loc is not None: x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(2, image, w, h) - base_image.paste(image, (x, y)) - image = base_image + if image.width != w or image.height != h or x != 0 or y != 0: + base_image = Image.new('RGBA', (overlay.width, overlay.height)) + image = images.resize_image(2, image, w, h) + base_image.paste(image, (x, y)) + image = base_image image = image.convert('RGBA') image.alpha_composite(overlay) image = image.convert('RGB') @@ -121,7 +125,7 @@ class StableDiffusionProcessing: """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, latent_sampler: str = None, batch_size: int = 1, n_iter: int = 1, 
steps: int = 50, cfg_scale: float = 7.0, image_cfg_scale: float = None, clip_skip: int = 1, width: int = 512, height: int = 512, full_quality: bool = True, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, diffusers_guidance_rescale: float = 0.7, hdr_clamp: bool = False, hdr_boundary: float = 4.0, hdr_threshold: float = 3.5, hdr_center: bool = False, hdr_channel_shift: float = 0.8, hdr_full_shift: float = 0.8, hdr_maximize: bool = False, hdr_max_center: float = 0.6, hdr_max_boundry: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): # pylint: disable=unused-argument + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, latent_sampler: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, image_cfg_scale: float = None, clip_skip: int = 1, width: int = 512, height: int = 512, full_quality: bool = True, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, diffusers_guidance_rescale: float = 0.7, resize_mode: int = 0, resize_name: str = 'None', scale_by: float = 0, selected_scale_tab: int = 0, hdr_clamp: bool = False, hdr_boundary: float = 4.0, hdr_threshold: float = 3.5, hdr_center: bool = False, hdr_channel_shift: float = 0.8, hdr_full_shift: float = 0.8, hdr_maximize: bool = False, hdr_max_center: float = 0.6, hdr_max_boundry: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): # pylint: disable=unused-argument self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids self.prompt: str = prompt @@ -140,9 +144,10 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom self.steps: int = steps self.hr_second_pass_steps = 0 self.cfg_scale: float = cfg_scale + self.scale_by: float = scale_by self.image_cfg_scale = image_cfg_scale self.diffusers_guidance_rescale = diffusers_guidance_rescale - if (devices.backend == "ipex" or shared.cmd_opts.use_openvino) and width == 1024 and height == 1024: + if devices.backend == "ipex" and width == 1024 and height == 1024 and os.environ.get('DISABLE_IPEX_1024_WA', None) is None: width = 1080 height = 1080 self.width: int = width @@ -174,14 +179,27 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom self.all_subseeds = None self.clip_skip = clip_skip self.iteration = 0 + self.is_control = False self.is_hr_pass = False self.is_refiner_pass = False self.hr_force = False self.enable_hr = None + self.hr_scale = None + self.hr_upscaler = None + self.hr_resize_x = 0 + self.hr_resize_y = 0 + self.hr_upscale_to_x = 0 + self.hr_upscale_to_y = 0 + self.truncate_x = 0 + 
self.truncate_y = 0 + self.applied_old_hires_behavior_to = None self.refiner_steps = 5 self.refiner_start = 0 + self.refiner_prompt = '' + self.refiner_negative = '' self.ops = [] - self.resize_mode: int = 0 + self.resize_mode: int = resize_mode + self.resize_name: str = resize_name self.ddim_discretize = shared.opts.ddim_discretize self.s_min_uncond = shared.opts.s_min_uncond self.s_churn = shared.opts.s_churn @@ -192,7 +210,7 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom self.s_tmax = float('inf') # not representable as a standard ui option shared.opts.data['clip_skip'] = clip_skip self.task_args = {} - # TODO a1111 compatibility items + # a1111 compatibility items self.refiner_switch_at = 0 self.hr_prompt = '' self.all_hr_prompts = [] @@ -213,6 +231,11 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom self.hdr_maximize = hdr_maximize self.hdr_max_center = hdr_max_center self.hdr_max_boundry = hdr_max_boundry + self.scheduled_prompt: bool = False + self.prompt_embeds = [] + self.positive_pooleds = [] + self.negative_embeds = [] + self.negative_pooleds = [] @property @@ -356,8 +379,8 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", self.subseed_strength = p.subseed_strength self.info = info self.comments = comments - self.width = p.width - self.height = p.height + self.width = p.width if hasattr(p, 'width') else (self.images[0].width if len(self.images) > 0 else 0) + self.height = p.height if hasattr(p, 'height') else (self.images[0].height if len(self.images) > 0 else 0) self.sampler_name = p.sampler_name self.cfg_scale = p.cfg_scale self.image_cfg_scale = p.image_cfg_scale @@ -423,7 +446,7 @@ def js(self): "styles": self.styles, "job_timestamp": self.job_timestamp, "clip_skip": self.clip_skip, - "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, + # "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, } return json.dumps(obj) @@ -542,13 +565,13 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No if index is None: index = position_in_batch + iteration * p.batch_size if all_prompts is None: - all_prompts = p.all_prompts + all_prompts = p.all_prompts or [p.prompt] if all_negative_prompts is None: - all_negative_prompts = p.all_negative_prompts + all_negative_prompts = p.all_negative_prompts or [p.negative_prompt] if all_seeds is None: - all_seeds = p.all_seeds + all_seeds = p.all_seeds or [p.seed] if all_subseeds is None: - all_subseeds = p.all_subseeds + all_subseeds = p.all_subseeds or [p.subseed] while len(all_prompts) <= index: all_prompts.append(all_prompts[-1]) while len(all_seeds) <= index: @@ -566,15 +589,13 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No "Seed": all_seeds[index], "Sampler": p.sampler_name, "CFG scale": p.cfg_scale, - "Size": f"{p.width}x{p.height}", + "Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None, "Batch": f'{p.n_iter}x{p.batch_size}' if p.n_iter > 1 or p.batch_size > 1 else None, "Index": f'{p.iteration + 1}x{index + 1}' if (p.n_iter > 1 or p.batch_size > 1) and index >= 0 else None, "Parser": shared.opts.prompt_attention, "Model": None if (not shared.opts.add_model_name_to_info) or (not shared.sd_model.sd_checkpoint_info.model_name) else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', ''), "Model hash": getattr(p, 'sd_model_hash', None if (not 
shared.opts.add_model_hash_to_info) or (not shared.sd_model.sd_model_hash) else shared.sd_model.sd_model_hash), "VAE": (None if not shared.opts.add_model_name_to_info or modules.sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(modules.sd_vae.loaded_vae_file))[0]) if p.full_quality else 'TAESD', - "Variation seed": None if p.subseed_strength == 0 else all_subseeds[index], - "Variation strength": None if p.subseed_strength == 0 else p.subseed_strength, "Seed resize from": None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}", "Clip skip": p.clip_skip if p.clip_skip > 1 else None, "Prompt2": p.refiner_prompt if len(p.refiner_prompt) > 0 else None, @@ -590,6 +611,9 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No } if 'txt2img' in p.ops: pass + if shared.backend == shared.Backend.ORIGINAL: + args["Variation seed"] = None if p.subseed_strength == 0 else all_subseeds[index] + args["Variation strength"] = None if p.subseed_strength == 0 else p.subseed_strength if 'hires' in p.ops or 'upscale' in p.ops: args["Second pass"] = p.enable_hr args["Hires force"] = p.hr_force @@ -618,11 +642,11 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No args['Resize scale'] = getattr(p, 'scale_by', None) args["Mask blur"] = p.mask_blur if getattr(p, 'mask', None) is not None and getattr(p, 'mask_blur', 0) > 0 else None args["Denoising strength"] = getattr(p, 'denoising_strength', None) + if args["Size"] is None: + args["Size"] = args["Init image size"] # lookup by index if getattr(p, 'resize_mode', None) is not None: - RESIZE_MODES = ["None", "Resize fixed", "Crop and resize", "Resize and fill", "Latent upscale"] - args['Resize mode'] = RESIZE_MODES[p.resize_mode] - # TODO missing-by-index: inpainting_fill, inpaint_full_res, inpainting_mask_invert + args['Resize mode'] = shared.resize_modes[p.resize_mode] if 'face' in p.ops: args["Face restoration"] = shared.opts.face_restoration_model if 'color' in p.ops: @@ -667,12 +691,15 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No def process_images(p: StableDiffusionProcessing) -> Processed: + debug(f'Process images: {vars(p)}') if not hasattr(p.sd_model, 'sd_checkpoint_info'): return None if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner): p.scripts.before_process(p) stored_opts = {} for k, v in p.override_settings.copy().items(): + if shared.opts.data.get(k, None) is None and shared.opts.data_labels.get(k, None) is None: + continue orig = shared.opts.data.get(k, None) or shared.opts.data_labels[k].default if orig == v or (type(orig) == str and os.path.splitext(orig)[0] == v): p.override_settings.pop(k, None) @@ -705,7 +732,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: modules.sd_vae.reload_vae_weights() shared.prompt_styles.apply_styles_to_extra(p) - if not shared.opts.cuda_compile: modules.sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio()) modules.sd_hijack_freeu.apply_freeu(p, shared.backend == shared.Backend.ORIGINAL) @@ -754,10 +780,13 @@ def validate_sample(tensor): return tensor if tensor.dtype == torch.bfloat16: # numpy does not support bf16 tensor = tensor.to(torch.float16) - if shared.backend == shared.Backend.ORIGINAL: - sample = 255.0 * np.moveaxis(tensor.cpu().numpy(), 0, 2) + if isinstance(tensor, torch.Tensor) and hasattr(tensor, 'detach'): + sample = tensor.detach().cpu().numpy() + elif 
isinstance(tensor, np.ndarray): + sample = tensor else: - sample = 255.0 * tensor + shared.log.warning(f'Unknown sample type: {type(tensor)}') + sample = 255.0 * np.moveaxis(sample, 0, 2) if shared.backend == shared.Backend.ORIGINAL else 255.0 * sample with warnings.catch_warnings(record=True) as w: cast = sample.astype(np.uint8) if len(w) > 0: @@ -770,20 +799,9 @@ def validate_sample(tensor): return cast -def process_images_inner(p: StableDiffusionProcessing) -> Processed: - """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - - if type(p.prompt) == list: - assert len(p.prompt) > 0 - else: - assert p.prompt is not None - +def process_init(p: StableDiffusionProcessing): seed = get_fixed_seed(p.seed) subseed = get_fixed_seed(p.subseed) - if shared.backend == shared.Backend.ORIGINAL: - modules.sd_hijack.model_hijack.apply_circular(p.tiling) - modules.sd_hijack.model_hijack.clear_comments() - comments = {} if type(p.prompt) == list: p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt] else: @@ -800,15 +818,32 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.all_subseeds = subseed else: p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] - if os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings and shared.backend == shared.Backend.ORIGINAL: - modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=False) - if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner): - p.scripts.process(p) + + +def process_images_inner(p: StableDiffusionProcessing) -> Processed: + """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + + if type(p.prompt) == list: + assert len(p.prompt) > 0 + else: + assert p.prompt is not None + + if shared.backend == shared.Backend.ORIGINAL: + modules.sd_hijack.model_hijack.apply_circular(p.tiling) + modules.sd_hijack.model_hijack.clear_comments() + comments = {} infotexts = [] output_images = [] cached_uc = [None, None] cached_c = [None, None] + process_init(p) + if os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings and shared.backend == shared.Backend.ORIGINAL: + modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=False) + if p.scripts is not None and isinstance(p.scripts, modules.scripts.ScriptRunner): + p.scripts.process(p) + + def get_conds_with_caching(function, required_prompts, steps, cache): if cache[0] is not None and (required_prompts, steps) == cache[0]: return cache[1] @@ -827,6 +862,7 @@ def infotext(_inxex=0): # dummy function overriden if there are iterations with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) extra_network_data = None + debug(f'Processing inner: args={vars(p)}') for n in range(p.n_iter): p.iteration = n if shared.state.skipped: @@ -883,7 +919,7 @@ def infotext(_inxex=0): # dummy function overriden if there are iterations elif shared.backend == shared.Backend.DIFFUSERS: from modules.processing_diffusers import process_diffusers - x_samples_ddim = process_diffusers(p, p.seeds, p.prompts, p.negative_prompts) + x_samples_ddim = process_diffusers(p) else: raise ValueError(f"Unknown backend {shared.backend}") @@ -1008,7 +1044,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def __init__(self, enable_hr: 
bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_force: bool = False, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0, refiner_prompt: str = '', refiner_negative: str = '', **kwargs): super().__init__(**kwargs) - if devices.backend == "ipex" or shared.cmd_opts.use_openvino: + if devices.backend == "ipex" and os.environ.get('DISABLE_IPEX_1024_WA', None) is None: width_curse = bool(hr_resize_x == 1024 and self.height * (hr_resize_x / self.width) == 1024) height_curse = bool(hr_resize_y == 1024 and self.width * (hr_resize_y / self.height) == 1024) if (width_curse != height_curse) or (height_curse and width_curse): @@ -1180,16 +1216,19 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): - def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.3, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, refiner_steps: int = 5, refiner_start: float = 0, refiner_prompt: str = '', refiner_negative: str = '', **kwargs): + def __init__(self, init_images: list = None, resize_mode: int = 0, resize_name: str = 'None', denoising_strength: float = 0.3, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, scale_by: float = 1, refiner_steps: int = 5, refiner_start: float = 0, refiner_prompt: str = '', refiner_negative: str = '', **kwargs): super().__init__(**kwargs) self.init_images = init_images self.resize_mode: int = resize_mode + self.resize_name: str = resize_name self.denoising_strength: float = denoising_strength self.image_cfg_scale: float = image_cfg_scale self.init_latent = None self.image_mask = mask self.latent_mask = None self.mask_for_overlay = None + self.mask_blur_x: int = 4 + self.mask_blur_y: int = 4 self.mask_blur = mask_blur self.inpainting_fill = inpainting_fill self.inpaint_full_res = inpaint_full_res @@ -1205,15 +1244,25 @@ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_str self.refiner_negative = refiner_negative self.enable_hr = None self.is_batch = False - self.scale_by = 1.0 + self.scale_by = scale_by self.sampler = None self.scripts = None self.script_args = [] + @property + def mask_blur(self): + mask_blur = max(self.mask_blur_x, self.mask_blur_y) + return mask_blur + + @mask_blur.setter + def mask_blur(self, value): + self.mask_blur_x = value + self.mask_blur_y = value + def init(self, all_prompts, all_seeds, all_subseeds): - if shared.backend == shared.Backend.DIFFUSERS and self.image_mask is not None: + if shared.backend == shared.Backend.DIFFUSERS and self.image_mask is not None and not self.is_control: shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, modules.sd_models.DiffusersTaskType.INPAINTING) - elif shared.backend == shared.Backend.DIFFUSERS and self.image_mask is None: + elif shared.backend == shared.Backend.DIFFUSERS and self.image_mask is None and not self.is_control: shared.sd_model = modules.sd_models.set_diffuser_pipe(self.sd_model, 
modules.sd_models.DiffusersTaskType.IMAGE_2_IMAGE) if self.sampler_name == "PLMS": @@ -1232,19 +1281,28 @@ def init(self, all_prompts, all_seeds, all_subseeds): if image_mask is not None: if type(image_mask) == list: image_mask = image_mask[0] - image_mask = image_mask.convert('L') + image_mask = create_binary_mask(image_mask) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) - if self.mask_blur > 0: - image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) + if self.mask_blur_x > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) + image_mask = Image.fromarray(np_mask) + if self.mask_blur_y > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) + image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') crop_region = modules.masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = modules.masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) x1, y1, x2, y2 = crop_region + mask = mask.crop(crop_region) - image_mask = images.resize_image(3, mask, self.width, self.height) + image_mask = images.resize_image(2, mask, self.width, self.height) self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) @@ -1252,6 +1310,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) self.mask_for_overlay = Image.fromarray(np_mask) self.overlay_images = [] + latent_mask = self.latent_mask if self.latent_mask is not None else image_mask add_color_corrections = shared.opts.img2img_color_correction and self.color_corrections is None @@ -1259,15 +1318,23 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.color_corrections = [] imgs = [] unprocessed = [] + if getattr(self, 'init_images', None) is None: + return + if not isinstance(self.init_images, list): + self.init_images = [self.init_images] for img in self.init_images: + if img is None: + shared.log.warning(f"Skipping empty image: images={self.init_images}") + continue self.init_img_hash = hashlib.sha256(img.tobytes()).hexdigest()[0:8] # pylint: disable=attribute-defined-outside-init self.init_img_width = img.width # pylint: disable=attribute-defined-outside-init self.init_img_height = img.height # pylint: disable=attribute-defined-outside-init if shared.opts.save_init_img: images.save_image(img, path=shared.opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, suffix="-init-image") image = images.flatten(img, shared.opts.img2img_background_color) - if crop_region is None and self.resize_mode != 4: - image = images.resize_image(self.resize_mode, image, self.width, self.height) + if crop_region is None and self.resize_mode != 4 and self.resize_mode > 0: + if image.width != self.width or image.height != self.height: + image = images.resize_image(self.resize_mode, image, self.width, self.height, self.resize_name) self.width = image.width self.height = image.height if image_mask is not None: @@ -1275,15 +1342,15 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_masked = Image.new('RGBa', (image.width, image.height)) image_to_paste = 
image.convert("RGBA").convert("RGBa") image_to_mask = ImageOps.invert(self.mask_for_overlay.convert('L')) if self.mask_for_overlay is not None else None + image_to_mask = image_to_mask.resize((image.width, image.height), Image.Resampling.BILINEAR) if image_to_mask is not None else None image_masked.paste(image_to_paste, mask=image_to_mask) self.overlay_images.append(image_masked.convert('RGBA')) except Exception as e: shared.log.error(f"Failed to apply mask to image: {e}") - self.mask = image_mask # assign early for diffusers - # crop_region is not None if we are doing inpaint full res - if crop_region is not None: + if crop_region is not None: # crop_region is not None if we are doing inpaint full res image = image.crop(crop_region) - image = images.resize_image(3, image, self.width, self.height) + if image.width != self.width or image.height != self.height: + image = images.resize_image(3, image, self.width, self.height, self.resize_name) if image_mask is not None and self.inpainting_fill != 1: image = modules.masking.fill(image, latent_mask) if add_color_corrections: @@ -1304,7 +1371,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.batch_size = len(imgs) batch_images = np.array(imgs) else: - raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") + raise RuntimeError(f"Incorrect number of of images={len(imgs)} expected={self.batch_size} or less") if shared.backend == shared.Backend.DIFFUSERS: return # we've already set self.init_images and self.mask and we dont need any more processing diff --git a/modules/processing_correction.py b/modules/processing_correction.py index 704805db4..4d4bd8f60 100644 --- a/modules/processing_correction.py +++ b/modules/processing_correction.py @@ -8,56 +8,81 @@ from modules import shared -debug = shared.log.info if os.environ.get('SD_HDR_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = shared.log.trace if os.environ.get('SD_HDR_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: HDR') -def soft_clamp_tensor(input_tensor, threshold=0.8, boundary=4): +def soft_clamp_tensor(tensor, threshold=0.8, boundary=4): # shrinking towards the mean; will also remove outliers - if max(abs(input_tensor.max()), abs(input_tensor.min())) < boundary or threshold == 0: - return input_tensor - channel_dim = 1 + if max(abs(tensor.max()), abs(tensor.min())) < boundary or threshold == 0: + return tensor + channel_dim = 0 threshold *= boundary - max_vals = input_tensor.max(channel_dim, keepdim=True)[0] - max_replace = ((input_tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold - over_mask = input_tensor > threshold - min_vals = input_tensor.min(channel_dim, keepdim=True)[0] - min_replace = ((input_tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold - under_mask = input_tensor < -threshold - debug(f'HDE soft clamp: threshold={threshold} boundary={boundary}') - input_tensor = torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, input_tensor)) - return input_tensor - - -def center_tensor(input_tensor, channel_shift=1.0, full_shift=1.0, channels=[0, 1, 2, 3]): # pylint: disable=dangerous-default-value # noqa: B006 + max_vals = tensor.max(channel_dim, keepdim=True)[0] + max_replace = ((tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold + over_mask = tensor > threshold + min_vals = tensor.min(channel_dim, keepdim=True)[0] + min_replace = ((tensor + threshold) / (min_vals + 
threshold)) * (-boundary + threshold) - threshold + under_mask = tensor < -threshold + tensor = torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, tensor)) + debug(f'HDR soft clamp: threshold={threshold} boundary={boundary} shape={tensor.shape}') + return tensor + + +def center_tensor(tensor, channel_shift=1.0, full_shift=1.0, channels=[0, 1, 2, 3]): # pylint: disable=dangerous-default-value # noqa: B006 if channel_shift == 0 and full_shift == 0: - return input_tensor + return tensor means = [] for channel in channels: - means.append(input_tensor[0, channel].mean()) - input_tensor[0, channel] -= means[-1] * channel_shift - debug(f'HDR center: channel-shift{channel_shift} full-shift={full_shift} means={torch.stack(means)}') - input_tensor = input_tensor - input_tensor.mean() * full_shift - return input_tensor + means.append(tensor[0, channel].mean()) + # tensor[0, channel] -= means[-1] * channel_shift + tensor[channel] -= means[-1] * channel_shift + tensor = tensor - tensor.mean() * full_shift + debug(f'HDR center: channel-shift={channel_shift} full-shift={full_shift} means={torch.stack(means)} shape={tensor.shape}') + return tensor -def maximize_tensor(input_tensor, boundary=1.0, channels=[0, 1, 2]): # pylint: disable=dangerous-default-value # noqa: B006 +def maximize_tensor(tensor, boundary=1.0, _channels=[0, 1, 2]): # pylint: disable=dangerous-default-value # noqa: B006 if boundary == 1.0: - return input_tensor + return tensor boundary *= 4 - min_val = input_tensor.min() - max_val = input_tensor.max() + min_val = tensor.min() + max_val = tensor.max() normalization_factor = boundary / max(abs(min_val), abs(max_val)) - input_tensor[0, channels] *= normalization_factor - debug(f'HDR maximize: boundary={boundary} min={min_val} max={max_val} factor={normalization_factor}') - return input_tensor + # tensor[0, channels] *= normalization_factor + tensor *= normalization_factor + debug(f'HDR maximize: boundary={boundary} min={min_val} max={max_val} factor={normalization_factor} shape={tensor.shape}') + return tensor -def correction_callback(p, timestep, kwags): +def correction(p, timestep, latent): if timestep > 950 and p.hdr_clamp: - kwags["latents"] = soft_clamp_tensor(kwags["latents"], threshold=p.hdr_threshold, boundary=p.hdr_boundary) + p.extra_generation_params["HDR clamp"] = f'{p.hdr_threshold}/{p.hdr_boundary}' + latent = soft_clamp_tensor(latent, threshold=p.hdr_threshold, boundary=p.hdr_boundary) if timestep > 700 and p.hdr_center: - kwags["latents"] = center_tensor(kwags["latents"], channel_shift=p.hdr_channel_shift, full_shift=p.hdr_full_shift) + p.extra_generation_params["HDR center"] = f'{p.hdr_channel_shift}/{p.hdr_full_shift}' + latent = center_tensor(latent, channel_shift=p.hdr_channel_shift, full_shift=p.hdr_full_shift) if timestep > 1 and timestep < 100 and p.hdr_maximize: - kwags["latents"] = center_tensor(kwags["latents"], channel_shift=p.hdr_max_center, full_shift=1.0) - kwags["latents"] = maximize_tensor(kwags["latents"], boundary=p.hdr_max_boundry) - return kwags + p.extra_generation_params["HDR max"] = f'{p.hdr_max_center}/p.hdr_max_boundry' + latent = center_tensor(latent, channel_shift=p.hdr_max_center, full_shift=1.0) + latent = maximize_tensor(latent, boundary=p.hdr_max_boundry) + return latent + + +def correction_callback(p, timestep, kwargs): + if not p.hdr_clamp and not p.hdr_center and not p.hdr_maximize: + return kwargs + latents = kwargs["latents"] + # debug(f'HDR correction: latents={latents.shape}') + if len(latents.shape) == 4: # 
standard batched latent + for i in range(latents.shape[0]): + latents[i] = correction(p, timestep, latents[i]) + elif len(latents.shape) == 5 and latents.shape[0] == 1: # probably animatediff + latents = latents.squeeze(0).permute(1, 0, 2, 3) + for i in range(latents.shape[0]): + latents[i] = correction(p, timestep, latents[i]) + latents = latents.permute(1, 0, 2, 3).unsqueeze(0) + else: + shared.log.debug(f'HDR correction: unknown latent shape {latents.shape}') + kwargs["latents"] = latents + return kwargs diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py index e90cd0126..66877a83a 100644 --- a/modules/processing_diffusers.py +++ b/modules/processing_diffusers.py @@ -9,21 +9,32 @@ import modules.shared as shared import modules.sd_samplers as sd_samplers import modules.sd_models as sd_models -import modules.sd_vae as sd_vae -import modules.taesd.sd_vae_taesd as sd_vae_taesd import modules.images as images import modules.errors as errors from modules.processing import StableDiffusionProcessing, create_random_tensors import modules.prompt_parser_diffusers as prompt_parser_diffusers from modules.sd_hijack_hypertile import hypertile_set from modules.processing_correction import correction_callback +from modules.processing_vae import vae_encode, vae_decode -def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_prompts): +debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: DIFFUSERS') +debug_steps = shared.log.trace if os.environ.get('SD_STEPS_DEBUG', None) is not None else lambda *args, **kwargs: None +debug_steps('Trace: STEPS') + + +def process_diffusers(p: StableDiffusionProcessing): + debug(f'Process diffusers args: {vars(p)}') results = [] - is_refiner_enabled = p.enable_hr and p.refiner_steps > 0 and p.refiner_start > 0 and p.refiner_start < 1 and shared.sd_refiner is not None - if hasattr(p, 'init_images') and len(p.init_images) > 0: + def is_txt2img(): + return sd_models.get_diffusers_task(shared.sd_model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE + + def is_refiner_enabled(): + return p.enable_hr and p.refiner_steps > 0 and p.refiner_start > 0 and p.refiner_start < 1 and shared.sd_refiner is not None + + if getattr(p, 'init_images', None) is not None and len(p.init_images) > 0: tgt_width, tgt_height = 8 * math.ceil(p.init_images[0].width / 8), 8 * math.ceil(p.init_images[0].height / 8) if p.init_images[0].width != tgt_width or p.init_images[0].height != tgt_height: shared.log.debug(f'Resizing init images: original={p.init_images[0].width}x{p.init_images[0].height} target={tgt_width}x{tgt_height}') @@ -37,6 +48,10 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro p.mask_for_overlay = images.resize_image(1, p.mask_for_overlay, tgt_width, tgt_height, upscaler_name=None) def hires_resize(latents): # input=latents output=pil + if not torch.is_tensor(latents): + shared.log.warning('Hires: input is not tensor') + first_pass_images = vae_decode(latents=latents, model=shared.sd_model, full_quality=p.full_quality, output_type='pil') + return first_pass_images latent_upscaler = shared.latent_upscale_modes.get(p.hr_upscaler, None) shared.log.info(f'Hires: upscaler={p.hr_upscaler} width={p.hr_upscale_to_x} height={p.hr_upscale_to_y} images={latents.shape[0]}') if latent_upscaler is not None: @@ -57,11 +72,12 @@ def save_intermediate(latents, suffix): info=create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], 
iteration=p.iteration, position_in_batch=i) decoded = vae_decode(latents=latents, model=shared.sd_model, output_type='pil', full_quality=p.full_quality) for j in range(len(decoded)): - images.save_image(decoded[j], path=p.outpath_samples, basename="", seed=seeds[i], prompt=prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix=suffix) + images.save_image(decoded[j], path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix=suffix) - def diffusers_callback_legacy(step: int, _timestep: int, latents: torch.FloatTensor): + def diffusers_callback_legacy(step: int, timestep: int, latents: torch.FloatTensor): shared.state.sampling_step = step shared.state.current_latent = latents + latents = correction_callback(p, timestep, {'latents': latents}) if shared.state.interrupted or shared.state.skipped: raise AssertionError('Interrupted...') if shared.state.paused: @@ -84,107 +100,21 @@ def diffusers_callback(_pipe, step: int, timestep: int, kwargs: dict): if kwargs.get('latents', None) is None: return kwargs kwargs = correction_callback(p, timestep, kwargs) + if p.scheduled_prompt and hasattr(kwargs, 'prompt_embeds') and hasattr(kwargs, 'negative_prompt_embeds'): + try: + i = (step + 1) % len(p.prompt_embeds) + kwargs["prompt_embeds"] = p.prompt_embeds[i][0:1].repeat(1, kwargs["prompt_embeds"].shape[0], 1).view( + kwargs["prompt_embeds"].shape[0], kwargs["prompt_embeds"].shape[1], -1) + j = (step + 1) % len(p.negative_embeds) + kwargs["negative_prompt_embeds"] = p.negative_embeds[j][0:1].repeat(1, kwargs["negative_prompt_embeds"].shape[0], 1).view( + kwargs["negative_prompt_embeds"].shape[0], kwargs["negative_prompt_embeds"].shape[1], -1) + except Exception as e: + shared.log.debug(f"Callback: {e}") shared.state.current_latent = kwargs['latents'] if shared.cmd_opts.profile and shared.profiler is not None: shared.profiler.step() return kwargs - def full_vae_decode(latents, model): - t0 = time.time() - if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'): - shared.log.debug('Moving to CPU: model=UNet') - unet_device = model.unet.device - model.unet.to(devices.cpu) - devices.torch_gc() - if not shared.cmd_opts.lowvram and not shared.opts.diffusers_seq_cpu_offload and hasattr(model, 'vae'): - model.vae.to(devices.device) - latents.to(model.vae.device) - - upcast = (model.vae.dtype == torch.float16) and getattr(model.vae.config, 'force_upcast', False) and hasattr(model, 'upcast_vae') - if upcast: # this is done by diffusers automatically if output_type != 'latent' - model.upcast_vae() - latents = latents.to(next(iter(model.vae.post_quant_conv.parameters())).dtype) - - decoded = model.vae.decode(latents / model.vae.config.scaling_factor, return_dict=False)[0] - if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'): - model.unet.to(unet_device) - t1 = time.time() - shared.log.debug(f'VAE decode: name={sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "baked"} dtype={model.vae.dtype} upcast={upcast} images={latents.shape[0]} latents={latents.shape} time={round(t1-t0, 3)}') - return decoded - - def full_vae_encode(image, model): - shared.log.debug(f'VAE encode: name={sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "baked"} dtype={model.vae.dtype} upcast={model.vae.config.get("force_upcast", None)}') - if shared.opts.diffusers_move_unet and not getattr(model, 
'has_accelerate', False) and hasattr(model, 'unet'): - shared.log.debug('Moving to CPU: model=UNet') - unet_device = model.unet.device - model.unet.to(devices.cpu) - devices.torch_gc() - if not shared.cmd_opts.lowvram and not shared.opts.diffusers_seq_cpu_offload and hasattr(model, 'vae'): - model.vae.to(devices.device) - encoded = model.vae.encode(image.to(model.vae.device, model.vae.dtype)).latent_dist.sample() - if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'): - model.unet.to(unet_device) - return encoded - - def taesd_vae_decode(latents): - shared.log.debug(f'VAE decode: name=TAESD images={len(latents)} latents={latents.shape}') - if len(latents) == 0: - return [] - decoded = torch.zeros((len(latents), 3, latents.shape[2] * 8, latents.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device) - for i in range(latents.shape[0]): - decoded[i] = sd_vae_taesd.decode(latents[i]) - return decoded - - def taesd_vae_encode(image): - shared.log.debug(f'VAE encode: name=TAESD image={image.shape}') - encoded = sd_vae_taesd.encode(image) - return encoded - - def vae_decode(latents, model, output_type='np', full_quality=True): - t0 = time.time() - prev_job = shared.state.job - shared.state.job = 'vae' - if not torch.is_tensor(latents): # already decoded - return latents - if latents.shape[0] == 0: - shared.log.error(f'VAE nothing to decode: {latents.shape}') - return [] - if shared.state.interrupted or shared.state.skipped: - return [] - if not hasattr(model, 'vae'): - shared.log.error('VAE not found in model') - return [] - if latents.shape[0] == 4 and latents.shape[1] != 4: # likely animatediff latent - latents = latents.permute(1, 0, 2, 3) - if len(latents.shape) == 3: # lost a batch dim in hires - latents = latents.unsqueeze(0) - if full_quality: - decoded = full_vae_decode(latents=latents, model=shared.sd_model) - else: - decoded = taesd_vae_decode(latents=latents) - # TODO validate decoded sample diffusers - # decoded = validate_sample(decoded) - imgs = model.image_processor.postprocess(decoded, output_type=output_type) - shared.state.job = prev_job - if shared.cmd_opts.profile: - t1 = time.time() - shared.log.debug(f'Profile: VAE decode: {t1-t0:.2f}') - return imgs - - def vae_encode(image, model, full_quality=True): # pylint: disable=unused-variable - if shared.state.interrupted or shared.state.skipped: - return [] - if not hasattr(model, 'vae'): - shared.log.error('VAE not found in model') - return [] - tensor = TF.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae) - if full_quality: - tensor = tensor * 2 - 1 - latents = full_vae_encode(image=tensor, model=shared.sd_model) - else: - latents = taesd_vae_encode(image=tensor) - return latents - def fix_prompts(prompts, negative_prompts, prompts_2, negative_prompts_2): if type(prompts) is str: prompts = [prompts] @@ -211,10 +141,11 @@ def task_specific_kwargs(model): is_img2img_model = bool('Zero123' in shared.sd_model.__class__.__name__) if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE and not is_img2img_model: p.ops.append('txt2img') - task_args = { - 'height': 8 * math.ceil(p.height / 8), - 'width': 8 * math.ceil(p.width / 8), - } + if hasattr(p, 'width') and hasattr(p, 'height'): + task_args = { + 'width': 8 * math.ceil(p.width / 8), + 'height': 8 * math.ceil(p.height / 8), + } elif (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE or is_img2img_model) and len(getattr(p, 
'init_images' ,[])) > 0: p.ops.append('img2img') task_args = { @@ -224,8 +155,8 @@ def task_specific_kwargs(model): elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INSTRUCT and len(getattr(p, 'init_images' ,[])) > 0: p.ops.append('instruct') task_args = { - 'height': 8 * math.ceil(p.height / 8), - 'width': 8 * math.ceil(p.width / 8), + 'width': 8 * math.ceil(p.width / 8) if hasattr(p, 'width') else None, + 'height': 8 * math.ceil(p.height / 8) if hasattr(p, 'height') else None, 'image': p.init_images, 'strength': p.denoising_strength, } @@ -233,34 +164,17 @@ def task_specific_kwargs(model): p.ops.append('inpaint') if getattr(p, 'mask', None) is None: p.mask = TF.to_pil_image(torch.ones_like(TF.to_tensor(p.init_images[0]))).convert("L") + p.mask = shared.sd_model.mask_processor.blur(p.mask, blur_factor=p.mask_blur) width = 8 * math.ceil(p.init_images[0].width / 8) height = 8 * math.ceil(p.init_images[0].height / 8) - # option-1: use images as inputs task_args = { 'image': p.init_images, 'mask_image': p.mask, 'strength': p.denoising_strength, 'height': height, 'width': width, + # 'padding_mask_crop': p.inpaint_full_res_padding # done back in main processing method } - """ # option-2: preprocess images into latents using diffusers - vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) - image_processor = diffusers.image_processor.VaeImageProcessor(vae_scale_factor=vae_scale_factor) - mask_processor = diffusers.image_processor.VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True) - init_image = image_processor.preprocess(p.init_images[0], width=width, height=height) - mask_image = mask_processor.preprocess(p.mask, width=width, height=height) - task_args = {"image": p.init_images, "mask_image": p.mask, "strength": p.denoising_strength, "height": height, "width": width} - """ - """ # option-2: manually assemble masked image latents - masked_image_latents = [] - mask_image = TF.to_tensor(p.mask) - for init_image in p.init_images: - init_image = TF.to_tensor(p.init_images[0]) - masked_image = init_image * (mask_image > 0.5) - masked_image_latents.append(torch.cat([masked_image, mask_image], dim=0)) - masked_image_latents = torch.stack(masked_image_latents, dim=0).to(shared.device) - task_args = {"image": p.init_images, "mask_image": mask_image, "masked_image_latents": masked_image_latents, "strength": p.denoising_strength, "height": height, "width": width} - """ if model.__class__.__name__ == 'LatentConsistencyModelPipeline' and hasattr(p, 'init_images') and len(p.init_images) > 0: p.ops.append('lcm') init_latents = [vae_encode(image, model=shared.sd_model, full_quality=p.full_quality).squeeze(dim=0) for image in p.init_images] @@ -269,9 +183,10 @@ def task_specific_kwargs(model): init_latent = (1 - p.denoising_strength) * init_latent + init_noise task_args = { 'latents': init_latent.to(model.dtype), - 'width': p.width, - 'height': p.height, + 'width': p.width if hasattr(p, 'width') else None, + 'height': p.height if hasattr(p, 'height') else None, } + debug(f'Diffusers task specific args: {task_args}') return task_args def set_pipeline_args(model, prompts: list, negative_prompts: list, prompts_2: typing.Optional[list]=None, negative_prompts_2: typing.Optional[list]=None, desc:str='', **kwargs): @@ -281,48 +196,46 @@ def set_pipeline_args(model, prompts: list, negative_prompts: list, prompts_2: t args = {} signature = inspect.signature(type(model).__call__) possible = 
signature.parameters.keys() - generator_device = devices.cpu if shared.opts.diffusers_generator_device == "cpu" else shared.device - generator = [torch.Generator(generator_device).manual_seed(s) for s in seeds] - prompt_embed = None - pooled = None - negative_embed = None - negative_pooled = None + debug(f'Diffusers pipeline possible: {possible}') + if shared.opts.diffusers_generator_device == "Unset": + generator_device = None + generator = None + else: + generator_device = devices.cpu if shared.opts.diffusers_generator_device == "CPU" else shared.device + generator = [torch.Generator(generator_device).manual_seed(s) for s in p.seeds] prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompts(prompts, negative_prompts, prompts_2, negative_prompts_2) parser = 'Fixed attention' if shared.opts.prompt_attention != 'Fixed attention' and 'StableDiffusion' in model.__class__.__name__: try: - prompt_embed, pooled, negative_embed, negative_pooled = prompt_parser_diffusers.encode_prompts(model, prompts, negative_prompts, kwargs.pop("clip_skip", None)) + prompt_parser_diffusers.encode_prompts(model, p, prompts, negative_prompts, kwargs.get("num_inference_steps", 1), 0, kwargs.pop("clip_skip", None)) + # prompt_embed, pooled, negative_embed, negative_pooled = , , , , parser = shared.opts.prompt_attention except Exception as e: shared.log.error(f'Prompt parser encode: {e}') if os.environ.get('SD_PROMPT_DEBUG', None) is not None: errors.display(e, 'Prompt parser encode') if 'prompt' in possible: - if hasattr(model, 'text_encoder') and 'prompt_embeds' in possible and prompt_embed is not None: - if type(pooled) == list: - pooled = pooled[0] - if type(negative_pooled) == list: - negative_pooled = negative_pooled[0] - args['prompt_embeds'] = prompt_embed + if hasattr(model, 'text_encoder') and 'prompt_embeds' in possible and len(p.prompt_embeds) > 0 and p.prompt_embeds[0] is not None: + args['prompt_embeds'] = p.prompt_embeds[0] if 'XL' in model.__class__.__name__: - args['pooled_prompt_embeds'] = pooled + args['pooled_prompt_embeds'] = p.positive_pooleds[0] else: args['prompt'] = prompts if 'negative_prompt' in possible: - if hasattr(model, 'text_encoder') and 'negative_prompt_embeds' in possible and negative_embed is not None: - args['negative_prompt_embeds'] = negative_embed + if hasattr(model, 'text_encoder') and 'negative_prompt_embeds' in possible and len(p.negative_embeds) > 0 and p.negative_embeds[0] is not None: + args['negative_prompt_embeds'] = p.negative_embeds[0] if 'XL' in model.__class__.__name__: - args['negative_pooled_prompt_embeds'] = negative_pooled + args['negative_pooled_prompt_embeds'] = p.negative_pooleds[0] else: args['negative_prompt'] = negative_prompts if hasattr(model, 'scheduler') and hasattr(model.scheduler, 'noise_sampler_seed') and hasattr(model.scheduler, 'noise_sampler'): model.scheduler.noise_sampler = None # noise needs to be reset instead of using cached values - model.scheduler.noise_sampler_seed = seeds[0] # some schedulers have internal noise generator and do not use pipeline generator + model.scheduler.noise_sampler_seed = p.seeds[0] # some schedulers have internal noise generator and do not use pipeline generator if 'noise_sampler_seed' in possible: - args['noise_sampler_seed'] = seeds[0] + args['noise_sampler_seed'] = p.seeds[0] if 'guidance_scale' in possible: args['guidance_scale'] = p.cfg_scale - if 'generator' in possible: + if 'generator' in possible and generator is not None: args['generator'] = generator if 'output_type' in possible: 
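The generator handling in this hunk honors the `Unset` generator-device option and seeds one `torch.Generator` per image from `p.seeds`. A minimal sketch of that per-seed pattern, assuming plain torch and illustrative seed values (no SD.Next objects):

```python
import torch

def make_generators(seeds, device_name: str = "cpu"):
    # One generator per image keeps every image in a batch reproducible from its own seed.
    # Returning None mirrors the "Unset" option: the pipeline then falls back to its default RNG.
    if device_name.lower() == "unset":
        return None
    device = torch.device(device_name.lower())
    return [torch.Generator(device).manual_seed(int(s)) for s in seeds]

generators = make_generators([42, 43, 44], device_name="cpu")  # passed as the pipeline's `generator` argument
```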
args['output_type'] = 'np' @@ -332,7 +245,10 @@ def set_pipeline_args(model, prompts: list, negative_prompts: list, prompts_2: t args['callback'] = diffusers_callback_legacy elif 'callback_on_step_end_tensor_inputs' in possible: args['callback_on_step_end'] = diffusers_callback - args['callback_on_step_end_tensor_inputs'] = ['latents'] + if 'prompt_embeds' in possible and 'negative_prompt_embeds' in possible: + args['callback_on_step_end_tensor_inputs'] = ['latents', 'prompt_embeds', 'negative_prompt_embeds'] + else: + args['callback_on_step_end_tensor_inputs'] = ['latents'] for arg in kwargs: if arg in possible: # add kwargs args[arg] = kwargs[arg] @@ -344,10 +260,14 @@ def set_pipeline_args(model, prompts: list, negative_prompts: list, prompts_2: t if arg in possible: args[arg] = task_kwargs[arg] task_args = getattr(p, 'task_args', {}) + debug(f'Diffusers task args: {task_args}') for k, v in task_args.items(): - args[k] = v + if k in possible: + args[k] = v + else: + debug(f'Diffusers unknown task args: {k}={v}') - hypertile_set(p, hr=len(getattr(p, 'init_images', []))) + hypertile_set(p, hr=len(getattr(p, 'init_images', [])) > 0) clean = args.copy() clean.pop('callback', None) clean.pop('callback_steps', None) @@ -380,22 +300,23 @@ def set_pipeline_args(model, prompts: list, negative_prompts: list, prompts_2: t shared.log.debug(f'Diffuser pipeline: {model.__class__.__name__} task={sd_models.get_diffusers_task(model)} set={clean}') if p.hdr_clamp or p.hdr_center or p.hdr_maximize: txt = 'HDR:' - txt += f' Clamp threshold={p.hdr_threshold} boundary={p.hdr_boundary}' if p.hdr_clamp else 'Clamp off' - txt += f' Center channel-shift={p.hdr_channel_shift} full-shift={p.hdr_full_shift}' if p.hdr_center else 'Center off' - txt += f' Maximize boundary={p.hdr_max_boundry} center={p.hdr_max_center}' if p.hdr_maximize else 'Maximize off' + txt += f' Clamp threshold={p.hdr_threshold} boundary={p.hdr_boundary}' if p.hdr_clamp else ' Clamp off' + txt += f' Center channel-shift={p.hdr_channel_shift} full-shift={p.hdr_full_shift}' if p.hdr_center else ' Center off' + txt += f' Maximize boundary={p.hdr_max_boundry} center={p.hdr_max_center}' if p.hdr_maximize else ' Maximize off' shared.log.debug(txt) # components = [{ k: getattr(v, 'device', None) } for k, v in model.components.items()] # shared.log.debug(f'Diffuser pipeline components: {components}') if shared.cmd_opts.profile: t1 = time.time() shared.log.debug(f'Profile: pipeline args: {t1-t0:.2f}') + debug(f'Diffusers pipeline args: {args}') return args def recompile_model(hires=False): if shared.opts.cuda_compile and shared.opts.cuda_compile_backend != 'none': if shared.opts.cuda_compile_backend == "openvino_fx": - compile_height = p.height if not hires else p.hr_upscale_to_y - compile_width = p.width if not hires else p.hr_upscale_to_x + compile_height = p.height if not hires and hasattr(p, 'height') else p.hr_upscale_to_y + compile_width = p.width if not hires and hasattr(p, 'width') else p.hr_upscale_to_x if (shared.compiled_model_state is None or (not shared.compiled_model_state.first_pass and (shared.compiled_model_state.height != compile_height @@ -405,22 +326,33 @@ def recompile_model(hires=False): shared.log.info("OpenVINO: Recompiling base model") sd_models.unload_model_weights(op='model') sd_models.reload_model_weights(op='model') - if is_refiner_enabled: + if is_refiner_enabled(): shared.log.info("OpenVINO: Recompiling refiner") sd_models.unload_model_weights(op='refiner') sd_models.reload_model_weights(op='refiner') 
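`set_pipeline_args` filters everything through the pipeline's `__call__` signature, so options a given pipeline does not accept (including unknown `task_args`) are dropped and only debug-logged. A hedged sketch of that introspection idea; `fake_pipeline` is a stand-in for `type(model).__call__`, not a diffusers class:

```python
import inspect

def filter_kwargs(callable_obj, **kwargs):
    # Keep only the keyword arguments the target callable actually declares;
    # everything else is returned separately so it can be logged or ignored.
    possible = inspect.signature(callable_obj).parameters.keys()
    accepted = {k: v for k, v in kwargs.items() if k in possible}
    rejected = {k: v for k, v in kwargs.items() if k not in possible}
    return accepted, rejected

def fake_pipeline(prompt, width=512, height=512, guidance_scale=7.5):  # stand-in pipeline call
    return {"prompt": prompt, "width": width, "height": height, "guidance_scale": guidance_scale}

args, unknown = filter_kwargs(fake_pipeline, prompt="test", width=768, strength=0.5)
result = fake_pipeline(**args)  # 'strength' ends up in `unknown` because fake_pipeline does not accept it
```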
shared.compiled_model_state.height = compile_height shared.compiled_model_state.width = compile_width shared.compiled_model_state.batch_size = p.batch_size + + # Downcast UNET after OpenVINO compile + def downcast_openvino(op="base"): + if shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx": + if shared.compiled_model_state.first_pass and op == "base": shared.compiled_model_state.first_pass = False - else: - pass #Can be implemented for TensorRT or Olive - else: - pass #Do nothing if compile is disabled + if hasattr(shared.sd_model, "unet"): + shared.sd_model.unet.to(dtype=torch.float8_e4m3fn) + devices.torch_gc(force=True) + if shared.compiled_model_state.first_pass_refiner and op == "refiner": + shared.compiled_model_state.first_pass_refiner = False + if hasattr(shared.sd_refiner, "unet"): + shared.sd_refiner.unet.to(dtype=torch.float8_e4m3fn) + devices.torch_gc(force=True) def update_sampler(sd_model, second_pass=False): sampler_selection = p.latent_sampler if second_pass else p.sampler_name # is_karras_compatible = sd_model.__class__.__init__.__annotations__.get("scheduler", None) == diffusers.schedulers.scheduling_utils.KarrasDiffusionSchedulers + if sd_model.__class__.__name__ in ['AmusedPipeline']: + return # models with their own schedulers if hasattr(sd_model, 'scheduler') and sampler_selection != 'Default': sampler = sd_samplers.all_samplers_map.get(sampler_selection, None) if sampler is None: @@ -430,7 +362,7 @@ def update_sampler(sd_model, second_pass=False): # p.extra_generation_params['Sampler options'] = '' if len(getattr(p, 'init_images', [])) > 0: - while len(p.init_images) < len(prompts): + while len(p.init_images) < len(p.prompts): p.init_images.append(p.init_images[-1]) if shared.state.interrupted or shared.state.skipped: @@ -439,33 +371,41 @@ def update_sampler(sd_model, second_pass=False): if shared.opts.diffusers_move_base and not getattr(shared.sd_model, 'has_accelerate', False): shared.sd_model.to(devices.device) - is_img2img = bool(sd_models.get_diffusers_task(shared.sd_model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE or sd_models.get_diffusers_task(shared.sd_model) == sd_models.DiffusersTaskType.INPAINTING) - use_refiner_start = bool(is_refiner_enabled and not p.is_hr_pass and not is_img2img and p.refiner_start > 0 and p.refiner_start < 1) - use_denoise_start = bool(is_img2img and p.refiner_start > 0 and p.refiner_start < 1) + # pipeline type is set earlier in processing, but check for sanity + has_images = len(getattr(p, 'init_images' ,[])) > 0 or getattr(p, 'is_control', False) is True + if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE and not has_images: + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) # reset pipeline + if hasattr(shared.sd_model, 'unet') and hasattr(shared.sd_model.unet, 'config') and hasattr(shared.sd_model.unet.config, 'in_channels') and shared.sd_model.unet.config.in_channels == 9: + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # force pipeline + if len(getattr(p, 'init_images' ,[])) == 0: + p.init_images = [TF.to_pil_image(torch.rand((3, getattr(p, 'height', 512), getattr(p, 'width', 512))))] + + use_refiner_start = is_txt2img() and is_refiner_enabled() and not p.is_hr_pass and p.refiner_start > 0 and p.refiner_start < 1 + use_denoise_start = not is_txt2img() and p.refiner_start > 0 and p.refiner_start < 1 def calculate_base_steps(): - if is_img2img: 
+ if not is_txt2img(): if use_denoise_start and shared.sd_model_type == 'sdxl': steps = p.steps // (1 - p.refiner_start) - else: + elif p.denoising_strength > 0: steps = (p.steps // p.denoising_strength) + 1 + else: + steps = p.steps elif use_refiner_start and shared.sd_model_type == 'sdxl': steps = (p.steps // p.refiner_start) + 1 else: steps = p.steps - - if os.environ.get('SD_STEPS_DEBUG', None) is not None: - shared.log.debug(f'Steps: type=base input={p.steps} output={steps} refiner={use_refiner_start}') + debug_steps(f'Steps: type=base input={p.steps} output={steps} task={sd_models.get_diffusers_task(shared.sd_model)} refiner={use_refiner_start} denoise={p.denoising_strength} model={shared.sd_model_type}') return max(2, int(steps)) def calculate_hires_steps(): if p.hr_second_pass_steps > 0: steps = (p.hr_second_pass_steps // p.denoising_strength) + 1 - else: + elif p.denoising_strength > 0: steps = (p.steps // p.denoising_strength) + 1 - - if os.environ.get('SD_STEPS_DEBUG', None) is not None: - shared.log.debug(f'Steps: type=hires input={p.hr_second_pass_steps} output={steps} denoise={p.denoising_strength}') + else: + steps = 0 + debug_steps(f'Steps: type=hires input={p.hr_second_pass_steps} output={steps} denoise={p.denoising_strength} model={shared.sd_model_type}') return max(2, int(steps)) def calculate_refiner_steps(): @@ -473,29 +413,22 @@ def calculate_refiner_steps(): if p.refiner_start > 0 and p.refiner_start < 1: #steps = p.refiner_steps // (1 - p.refiner_start) # SDXL with denoise strenght steps = (p.refiner_steps // (1 - p.refiner_start) // 2) + 1 - else: + elif p.denoising_strength > 0: steps = (p.refiner_steps // p.denoising_strength) + 1 + else: + steps = 0 else: #steps = p.refiner_steps # SD 1.5 with denoise strenght steps = (p.refiner_steps * 1.25) + 1 - - if os.environ.get('SD_STEPS_DEBUG', None) is not None: - shared.log.debug(f'Steps: type=refiner input={p.refiner_steps} output={steps} start={p.refiner_start} denoise={p.denoising_strength}') + debug_steps(f'Steps: type=refiner input={p.refiner_steps} output={steps} start={p.refiner_start} denoise={p.denoising_strength}') return max(2, int(steps)) - # pipeline type is set earlier in processing, but check for sanity - if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE and len(getattr(p, 'init_images' ,[])) == 0: - shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) # reset pipeline - if hasattr(shared.sd_model, 'unet') and hasattr(shared.sd_model.unet, 'config') and hasattr(shared.sd_model.unet.config, 'in_channels') and shared.sd_model.unet.config.in_channels == 9: - shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # force pipeline - if len(getattr(p, 'init_images' ,[])) == 0: - p.init_images = [TF.to_pil_image(torch.rand((3, p.height, p.width)))] base_args = set_pipeline_args( model=shared.sd_model, - prompts=prompts, - negative_prompts=negative_prompts, - prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else prompts, - negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else negative_prompts, + prompts=p.prompts, + negative_prompts=p.negative_prompts, + prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts, + negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts, num_inference_steps=calculate_base_steps(), eta=shared.opts.scheduler_eta, guidance_scale=p.cfg_scale, @@ -510,15 
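The `calculate_*_steps` helpers stretch the requested step count whenever only part of the schedule is actually sampled, so the effective number of denoising iterations stays close to what was asked for. A simplified sketch with illustrative numbers; the real code also special-cases SDXL, hires and refiner passes:

```python
def base_steps(requested: int, denoising_strength: float = 1.0, refiner_start: float = 0.0, txt2img: bool = True) -> int:
    # When strength or refiner start trims the schedule, request proportionally
    # more steps so the part that actually runs still covers ~`requested` steps.
    if not txt2img and denoising_strength > 0:
        steps = (requested // denoising_strength) + 1   # img2img: strength skips the head of the schedule
    elif txt2img and 0 < refiner_start < 1:
        steps = (requested // refiner_start) + 1        # base pass hands off to the refiner early
    else:
        steps = requested
    return max(2, int(steps))

print(base_steps(20, denoising_strength=0.5, txt2img=False))  # 41 scheduled, roughly 20 actually sampled
```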
+443,20 @@ def calculate_refiner_steps(): update_sampler(shared.sd_model) shared.state.sampling_steps = base_args['num_inference_steps'] p.extra_generation_params['Pipeline'] = shared.sd_model.__class__.__name__ - p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1 else None + if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1: + p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta try: t0 = time.time() output = shared.sd_model(**base_args) # pylint: disable=not-callable + downcast_openvino(op="base") if shared.cmd_opts.profile: t1 = time.time() shared.log.debug(f'Profile: pipeline call: {t1-t0:.2f}') if not hasattr(output, 'images') and hasattr(output, 'frames'): - shared.log.debug(f'Generated: frames={len(output.frames[0])}') + if hasattr(output.frames[0], 'shape'): + shared.log.debug(f'Generated: frames={output.frames[0].shape[1]}') + else: + shared.log.debug(f'Generated: frames={len(output.frames[0])}') output.images = output.frames[0] except AssertionError as e: shared.log.info(e) @@ -546,7 +484,7 @@ def calculate_refiner_steps(): if p.is_hr_pass: p.init_hr() prev_job = shared.state.job - if p.width != p.hr_upscale_to_x or p.height != p.hr_upscale_to_y: + if hasattr(p, 'height') and hasattr(p, 'width') and (p.width != p.hr_upscale_to_x or p.height != p.hr_upscale_to_y): p.ops.append('upscale') if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_highres_fix and hasattr(shared.sd_model, 'vae'): save_intermediate(latents=output.images, suffix="-before-hires") @@ -559,10 +497,10 @@ def calculate_refiner_steps(): update_sampler(shared.sd_model, second_pass=True) hires_args = set_pipeline_args( model=shared.sd_model, - prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else prompts, - negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else negative_prompts, - prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else prompts, - negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else negative_prompts, + prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts, + negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts, + prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts, + negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts, num_inference_steps=calculate_hires_steps(), eta=shared.opts.scheduler_eta, guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale, @@ -577,6 +515,7 @@ def calculate_refiner_steps(): shared.state.sampling_steps = hires_args['num_inference_steps'] try: output = shared.sd_model(**hires_args) # pylint: disable=not-callable + downcast_openvino(op="base") except AssertionError as e: shared.log.info(e) p.init_images = [] @@ -585,7 +524,7 @@ def calculate_refiner_steps(): p.is_hr_pass = False # optional refiner pass or decode - if is_refiner_enabled: + if is_refiner_enabled(): prev_job = shared.state.job shared.state.job = 'refine' shared.state.job_count +=1 @@ -617,8 +556,8 @@ def calculate_refiner_steps(): output_type = 'np' refiner_args = set_pipeline_args( model=shared.sd_refiner, - prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else prompts[i], - negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else negative_prompts[i], + 
prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts[i], + negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts[i], num_inference_steps=calculate_refiner_steps(), eta=shared.opts.scheduler_eta, # strength=p.denoising_strength, @@ -634,7 +573,9 @@ def calculate_refiner_steps(): ) shared.state.sampling_steps = refiner_args['num_inference_steps'] try: + shared.sd_refiner.register_to_config(requires_aesthetics_score=shared.opts.diffusers_aesthetics_score) refiner_output = shared.sd_refiner(**refiner_args) # pylint: disable=not-callable + downcast_openvino(op="refiner") except AssertionError as e: shared.log.info(e) @@ -652,9 +593,16 @@ def calculate_refiner_steps(): p.is_refiner_pass = False # final decode since there is no refiner - if not is_refiner_enabled: - if output is not None and output.images is not None and len(output.images) > 0: - results = vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality) + if not is_refiner_enabled(): + if output is not None: + if not hasattr(output, 'images') and hasattr(output, 'frames'): + shared.log.debug(f'Generated: frames={len(output.frames[0])}') + output.images = output.frames[0] + if output.images is not None and len(output.images) > 0: + results = vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality) + else: + shared.log.warning('Processing returned no results') + results = [] else: shared.log.warning('Processing returned no results') results = [] diff --git a/modules/processing_vae.py b/modules/processing_vae.py new file mode 100644 index 000000000..bd9425de7 --- /dev/null +++ b/modules/processing_vae.py @@ -0,0 +1,147 @@ +import os +import time +import torch +import torchvision.transforms.functional as TF +from modules import shared, devices, sd_vae +import modules.taesd.sd_vae_taesd as sd_vae_taesd + + +debug = shared.log.trace if os.environ.get('SD_VAE_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: VAE') + + +def create_latents(image, p, dtype=None, device=None): + from modules.processing import create_random_tensors + from PIL import Image + if image is None: + return image + elif isinstance(image, Image.Image): + latents = vae_encode(image, model=shared.sd_model, full_quality=p.full_quality) + elif isinstance(image, list): + latents = [vae_encode(i, model=shared.sd_model, full_quality=p.full_quality).squeeze(dim=0) for i in image] + latents = torch.stack(latents, dim=0).to(shared.device) + else: + shared.log.warning(f'Latents: input type: {type(image)} {image}') + return image + noise = p.denoising_strength * create_random_tensors(latents.shape[1:], seeds=p.all_seeds, subseeds=p.all_subseeds, subseed_strength=p.subseed_strength, p=p) + latents = (1 - p.denoising_strength) * latents + noise + if dtype is not None: + latents = latents.to(dtype=dtype) + if device is not None: + latents = latents.to(device=device) + return latents + + +def full_vae_decode(latents, model): + t0 = time.time() + if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'): + shared.log.debug('Moving to CPU: model=UNet') + unet_device = model.unet.device + model.unet.to(devices.cpu) + devices.torch_gc() + if not shared.cmd_opts.lowvram and not shared.opts.diffusers_seq_cpu_offload and hasattr(model, 'vae'): + model.vae.to(devices.device) + latents.to(model.vae.device) + + upcast = (model.vae.dtype == torch.float16) and getattr(model.vae.config, 'force_upcast', False) and 
hasattr(model, 'upcast_vae') + if upcast: # this is done by diffusers automatically if output_type != 'latent' + model.upcast_vae() + latents = latents.to(next(iter(model.vae.post_quant_conv.parameters())).dtype) + + decoded = model.vae.decode(latents / model.vae.config.scaling_factor, return_dict=False)[0] + + # Downcast VAE after OpenVINO compile + if shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx" and shared.compiled_model_state.first_pass_vae: + shared.compiled_model_state.first_pass_vae = False + if hasattr(shared.sd_model, "vae"): + model.vae.to(dtype=torch.float8_e4m3fn) + devices.torch_gc(force=True) + + if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'): + model.unet.to(unet_device) + t1 = time.time() + debug(f'VAE decode: name={sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "baked"} dtype={model.vae.dtype} upcast={upcast} images={latents.shape[0]} latents={latents.shape} time={round(t1-t0, 3)}') + return decoded + + +def full_vae_encode(image, model): + debug(f'VAE encode: name={sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "baked"} dtype={model.vae.dtype} upcast={model.vae.config.get("force_upcast", None)}') + if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'): + debug('Moving to CPU: model=UNet') + unet_device = model.unet.device + model.unet.to(devices.cpu) + devices.torch_gc() + if not shared.cmd_opts.lowvram and not shared.opts.diffusers_seq_cpu_offload and hasattr(model, 'vae'): + model.vae.to(devices.device) + encoded = model.vae.encode(image.to(model.vae.device, model.vae.dtype)).latent_dist.sample() + if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'): + model.unet.to(unet_device) + return encoded + + +def taesd_vae_decode(latents): + debug(f'VAE decode: name=TAESD images={len(latents)} latents={latents.shape}') + if len(latents) == 0: + return [] + decoded = torch.zeros((len(latents), 3, latents.shape[2] * 8, latents.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device) + for i in range(latents.shape[0]): + decoded[i] = sd_vae_taesd.decode(latents[i]) + return decoded + + +def taesd_vae_encode(image): + debug(f'VAE encode: name=TAESD image={image.shape}') + encoded = sd_vae_taesd.encode(image) + return encoded + + +def vae_decode(latents, model, output_type='np', full_quality=True): + t0 = time.time() + prev_job = shared.state.job + shared.state.job = 'vae' + if not torch.is_tensor(latents): # already decoded + return latents + if latents.shape[0] == 0: + shared.log.error(f'VAE nothing to decode: {latents.shape}') + return [] + if shared.state.interrupted or shared.state.skipped: + return [] + if not hasattr(model, 'vae'): + shared.log.error('VAE not found in model') + return [] + if latents.shape[0] == 4 and latents.shape[1] != 4: # likely animatediff latent + latents = latents.permute(1, 0, 2, 3) + if len(latents.shape) == 3: # lost a batch dim in hires + latents = latents.unsqueeze(0) + if full_quality: + decoded = full_vae_decode(latents=latents, model=shared.sd_model) + else: + decoded = taesd_vae_decode(latents=latents) + # TODO validate decoded sample diffusers + # decoded = validate_sample(decoded) + if hasattr(model, 'image_processor'): + imgs = model.image_processor.postprocess(decoded, output_type=output_type) + else: + import diffusers + image_processor = diffusers.image_processor.VaeImageProcessor() + imgs 
= image_processor.postprocess(decoded, output_type=output_type) + shared.state.job = prev_job + if shared.cmd_opts.profile: + t1 = time.time() + shared.log.debug(f'Profile: VAE decode: {t1-t0:.2f}') + return imgs + + +def vae_encode(image, model, full_quality=True): # pylint: disable=unused-variable + if shared.state.interrupted or shared.state.skipped: + return [] + if not hasattr(model, 'vae'): + shared.log.error('VAE not found in model') + return [] + tensor = TF.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae) + if full_quality: + tensor = tensor * 2 - 1 + latents = full_vae_encode(image=tensor, model=shared.sd_model) + else: + latents = taesd_vae_encode(image=tensor) + return latents diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 5a950d443..f3c6b30c6 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -70,7 +70,8 @@ debug_output = os.environ.get('SD_PROMPT_DEBUG', None) -debug = log.info if debug_output is not None else lambda *args, **kwargs: None +debug = log.trace if debug_output is not None else lambda *args, **kwargs: None +debug('Trace: PROMPT') def get_learned_conditioning_prompt_schedules(prompts, steps): diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py index abd06b778..6ec9df66c 100644 --- a/modules/prompt_parser_diffusers.py +++ b/modules/prompt_parser_diffusers.py @@ -1,14 +1,13 @@ import os +import time import typing import torch from compel import ReturnedEmbeddingsType from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider from modules import shared, prompt_parser, devices - -debug = shared.log.info if os.environ.get('SD_PROMPT_DEBUG', None) is not None else lambda *args, **kwargs: None - - +debug = shared.log.trace if os.environ.get('SD_PROMPT_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: PROMPT') CLIP_SKIP_MAPPING = { None: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED, 1: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED, @@ -59,31 +58,55 @@ def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List return self.pipe.tokenizer.encode(prompt, add_special_tokens=False) -def encode_prompts(pipe, prompts: list, negative_prompts: list, clip_skip: typing.Optional[int] = None): - if 'StableDiffusion' not in pipe.__class__.__name__: +def get_prompt_schedule(p, prompt, steps): # pylint: disable=unused-argument + t0 = time.time() + temp = [] + schedule = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], steps)[0] + if all(x == schedule[0] for x in schedule): + return [prompt], False + for chunk in schedule: + for s in range(steps): + if len(temp) < s + 1 <= chunk[0]: + temp.append(chunk[1]) + debug(f'Prompt: schedule={temp} time={time.time() - t0}') + return temp, len(schedule) > 1 + + +def encode_prompts(pipe, p, prompts: list, negative_prompts: list, steps: int, step: int = 1, clip_skip: typing.Optional[int] = None): # pylint: disable=unused-argument + if 'StableDiffusion' not in pipe.__class__.__name__ and 'DemoFusion': shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}") return None, None, None, None else: - prompt_embeds = [] - positive_pooleds = [] - negative_embeds = [] - negative_pooleds = [] - for i in range(len(prompts)): - prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, prompts[i], negative_prompts[i], clip_skip) - prompt_embeds.append(prompt_embed) - 
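`vae_decode` either runs the model's own VAE (dividing by `scaling_factor`, optionally upcasting) or TAESD for fast previews, then converts the tensors to images with a `VaeImageProcessor`. A rough sketch of the full-quality path using plain diffusers primitives; `pipe` and the latent shape are assumptions, not SD.Next code:

```python
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

@torch.no_grad()
def decode_latents(vae: AutoencoderKL, latents: torch.Tensor, output_type: str = "pil"):
    # Full-quality path: undo the latent scaling the VAE was trained with,
    # decode to pixel space, then denormalize/convert via VaeImageProcessor.
    latents = latents.to(device=vae.device, dtype=vae.dtype)
    decoded = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    processor = VaeImageProcessor(vae_scale_factor=2 ** (len(vae.config.block_out_channels) - 1))
    return processor.postprocess(decoded, output_type=output_type)

# usage (assumes a loaded pipeline `pipe` and a latent batch shaped [batch, 4, height/8, width/8])
# images = decode_latents(pipe.vae, latents)
```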
positive_pooleds.append(positive_pooled) - negative_embeds.append(negative_embed) - negative_pooleds.append(negative_pooled) - - if prompt_embeds is not None: - prompt_embeds = torch.cat(prompt_embeds, dim=0) - if negative_embeds is not None: - negative_embeds = torch.cat(negative_embeds, dim=0) - if positive_pooleds is not None and shared.sd_model_type == "sdxl": - positive_pooleds = torch.cat(positive_pooleds, dim=0) - if negative_pooleds is not None and shared.sd_model_type == "sdxl": - negative_pooleds = torch.cat(negative_pooleds, dim=0) - return prompt_embeds, positive_pooleds, negative_embeds, negative_pooleds + t0 = time.time() + positive_schedule, scheduled = get_prompt_schedule(p, prompts[0], steps) + negative_schedule, neg_scheduled = get_prompt_schedule(p, negative_prompts[0], steps) + p.scheduled_prompt = scheduled or neg_scheduled + + p.prompt_embeds = [] + p.positive_pooleds = [] + p.negative_embeds = [] + p.negative_pooleds = [] + + cache = {} + for i in range(max(len(positive_schedule), len(negative_schedule))): + cached = cache.get(positive_schedule[i % len(positive_schedule)] + negative_schedule[i % len(negative_schedule)], None) + if cached is not None: + prompt_embed, positive_pooled, negative_embed, negative_pooled = cached + else: + prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, + positive_schedule[i % len(positive_schedule)], + negative_schedule[i % len(negative_schedule)], + clip_skip) + if prompt_embed is not None: + p.prompt_embeds.append(torch.cat([prompt_embed] * len(prompts), dim=0)) + if negative_embed is not None: + p.negative_embeds.append(torch.cat([negative_embed] * len(negative_prompts), dim=0)) + if positive_pooled is not None and shared.sd_model_type == "sdxl": + p.positive_pooleds.append(torch.cat([positive_pooled] * len(prompts), dim=0)) + if negative_pooled is not None and shared.sd_model_type == "sdxl": + p.negative_pooleds.append(torch.cat([negative_pooled] * len(negative_prompts), dim=0)) + debug(f"Prompt Parser: Elapsed Time {time.time() - t0}") + return def get_prompts_with_weights(prompt: str): @@ -117,9 +140,9 @@ def prepare_embedding_providers(pipe, clip_skip): def pad_to_same_length(pipe, embeds): device = pipe.device if str(pipe.device) != 'meta' else devices.device - try: #SDXL + try: # SDXL empty_embed = pipe.encode_prompt("") - except Exception: #SD1.5 + except Exception: # SD1.5 empty_embed = pipe.encode_prompt("", device, 1, False) empty_batched = torch.cat([empty_embed[0].to(embeds[0].device)] * embeds[0].shape[0]) max_token_count = max([embed.shape[1] for embed in embeds]) @@ -153,7 +176,7 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c prompt_embeds = [] negative_prompt_embeds = [] pooled_prompt_embeds = None - negative_pooled_prompt_embeds = None + negative_pooled_prompt_embeds = None for i in range(len(embedding_providers)): # add BREAK keyword that splits the prompt into multiple fragments text = positives[i] @@ -163,10 +186,12 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c provider_embed = [] while 'BREAK' in text: pos = text.index('BREAK') - embed, ptokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[text[:pos]], fragment_weights_batch=[weights[:pos]], device=device, should_return_tokens=True) - provider_embed.append(embed) - text = text[pos+1:] - weights = weights[pos+1:] + debug(f'Prompt: section="{text[:pos]}" len={len(text[:pos])} weights={weights[:pos]}') + if 
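The reworked `encode_prompts` expands a prompt-editing schedule into one prompt per sampling step and caches the resulting embeddings per positive/negative pair. A simplified stand-in for that expansion step; the `(end_step, text)` pairs mirror the shape of the schedules returned by `prompt_parser.get_learned_conditioning_prompt_schedules`, and the example edit point is made up:

```python
def expand_schedule(schedule, steps):
    # `schedule` is a list of (end_step, prompt_text) pairs covering steps 1..steps,
    # e.g. a prompt edited at step 10 out of 20. Returns one prompt per sampling
    # step, which the per-step embedding cache can then be keyed on.
    per_step = []
    for end_step, text in schedule:
        while len(per_step) < min(end_step, steps):
            per_step.append(text)
    return per_step

prompts = expand_schedule([(10, "a cat"), (20, "a dog")], steps=20)
# steps 1-10 encode "a cat", steps 11-20 encode "a dog"
```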
len(text[:pos]) > 0: + embed, ptokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[text[:pos]], fragment_weights_batch=[weights[:pos]], device=device, should_return_tokens=True) + provider_embed.append(embed) + text = text[pos + 1:] + weights = weights[pos + 1:] prompt_embeds.append(torch.cat(provider_embed, dim=1)) # negative prompt has no keywords embed, ntokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[negatives[i]], fragment_weights_batch=[negative_weights[i]], device=device, should_return_tokens=True) @@ -175,23 +200,24 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c if prompt_embeds[-1].shape[-1] > 768: if shared.opts.diffusers_pooled == "weighted": pooled_prompt_embeds = prompt_embeds[-1][ - torch.arange(prompt_embeds[-1].shape[0], device=device), - (ptokens.to(dtype=torch.int, device=device) == 49407) - .int() - .argmax(dim=-1), - ] + torch.arange(prompt_embeds[-1].shape[0], device=device), + (ptokens.to(dtype=torch.int, device=device) == 49407) + .int() + .argmax(dim=-1), + ] negative_pooled_prompt_embeds = negative_prompt_embeds[-1][ - torch.arange(negative_prompt_embeds[-1].shape[0], device=device), - (ntokens.to(dtype=torch.int, device=device) == 49407) - .int() - .argmax(dim=-1), - ] + torch.arange(negative_prompt_embeds[-1].shape[0], device=device), + (ntokens.to(dtype=torch.int, device=device) == 49407) + .int() + .argmax(dim=-1), + ] else: pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[prompt_2], device=device) if prompt_embeds[-1].shape[-1] > 768 else None negative_pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[neg_prompt_2], device=device) if negative_prompt_embeds[-1].shape[-1] > 768 else None prompt_embeds = torch.cat(prompt_embeds, dim=-1) if len(prompt_embeds) > 1 else prompt_embeds[0] negative_prompt_embeds = torch.cat(negative_prompt_embeds, dim=-1) if len(negative_prompt_embeds) > 1 else negative_prompt_embeds[0] + debug(f'Prompt: shape={prompt_embeds.shape} negative={negative_prompt_embeds.shape}') if prompt_embeds.shape[1] != negative_prompt_embeds.shape[1]: [prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds]) return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds diff --git a/modules/rife/__init__.py b/modules/rife/__init__.py index a7db7d9a2..7e40735e7 100644 --- a/modules/rife/__init__.py +++ b/modules/rife/__init__.py @@ -11,12 +11,12 @@ from torch.nn import functional as F from tqdm.rich import tqdm from modules.rife.ssim import ssim_matlab -from modules.rife.model_rife import Model +from modules.rife.model_rife import RifeModel from modules import devices, shared model_url = 'https://github.com/vladmandic/rife/raw/main/model/flownet-v46.pkl' -model = None +model: RifeModel = None def load(model_path: str = 'rife/flownet-v46.pkl'): @@ -26,7 +26,7 @@ def load(model_path: str = 'rife/flownet-v46.pkl'): model_dir = os.path.join(shared.models_path, 'RIFE') model_path = modelloader.load_file_from_url(url=model_url, model_dir=model_dir, file_name='flownet-v46.pkl') shared.log.debug(f'RIFE load model: file="{model_path}"') - model = Model() + model = RifeModel() model.load_model(model_path, -1) model.eval() model.device() diff --git a/modules/rife/model_rife.py b/modules/rife/model_rife.py index ce8c363cd..52d359923 100644 --- a/modules/rife/model_rife.py +++ b/modules/rife/model_rife.py @@ -6,7 +6,7 
@@ from modules import devices -class Model: +class RifeModel: def __init__(self, local_rank=-1): self.flownet = IFNet() self.device() diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 47de61b51..f4825da17 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -101,7 +101,7 @@ def run(self, pp: PostprocessedImage, args): process_args = {} for (name, _component), value in zip(script.controls.items(), script_args): process_args[name] = value - shared.log.debug(f'postprocess: script={script.name} args={process_args}') + shared.log.debug(f'Process: script={script.name} args={process_args}') script.process(pp, **process_args) def create_args_for_run(self, scripts_args): @@ -110,15 +110,25 @@ def create_args_for_run(self, scripts_args): self.setup_ui() scripts = self.scripts_in_preferred_order() args = [None] * max([x.args_to for x in scripts]) - for script in scripts: script_args_dict = scripts_args.get(script.name, None) if script_args_dict is not None: for i, name in enumerate(script.controls): args[script.args_from + i] = script_args_dict.get(name, None) - return args def image_changed(self): for script in self.scripts_in_preferred_order(): script.image_changed() + + def postprocess(self, filenames, args): + for script in self.scripts_in_preferred_order(): + if not hasattr(script, 'postprocess'): + continue + shared.state.job = script.name + script_args = args[script.args_from:script.args_to] + process_args = {} + for (name, _component), value in zip(script.controls.items(), script_args): + process_args[name] = value + shared.log.debug(f'Postprocess: script={script.name} args={process_args}') + script.postprocess(filenames, **process_args) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index dd5c68f7c..a1975b0c0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -183,11 +183,6 @@ def hijack(self, m): import logging shared.log.info(f"Compiling pipeline={m.model.__class__.__name__} mode={opts.cuda_compile_backend}") import torch._dynamo # pylint: disable=unused-import,redefined-outer-name - if shared.opts.cuda_compile_backend == "openvino_fx": - torch._dynamo.reset() # pylint: disable=protected-access - from modules.intel.openvino import openvino_fx, openvino_clear_caches # pylint: disable=unused-import, no-name-in-module - openvino_clear_caches() - torch._dynamo.eval_frame.check_if_dynamo_supported = lambda: True # pylint: disable=protected-access log_level = logging.WARNING if opts.cuda_compile_verbose else logging.CRITICAL # pylint: disable=protected-access if hasattr(torch, '_logging'): torch._logging.set_logs(dynamo=log_level, aot=log_level, inductor=log_level) # pylint: disable=protected-access diff --git a/modules/sd_hijack_hypertile.py b/modules/sd_hijack_hypertile.py index 6c6213a48..dfd54dc64 100644 --- a/modules/sd_hijack_hypertile.py +++ b/modules/sd_hijack_hypertile.py @@ -48,6 +48,7 @@ def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=256 reset_needed = True nhs = possible_tile_sizes(height, tile_size, min_tile_size, swap_size) # possible sub-grids that fit into the image nws = possible_tile_sizes(width, tile_size, min_tile_size, swap_size) + def reset_nhs(): nonlocal nhs, ar ar = height / width # Aspect ratio @@ -73,7 +74,7 @@ def wrapper(*args, **kwargs): out = forward(x, *args[1:], **kwargs) return out if x.ndim == 4: # VAE - # TODO hyperlink vae breaks for diffusers when using non-standard sizes + # TODO hypertile vae breaks for diffusers when using 
non-standard sizes if nh * nw > 1: x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) out = forward(x, *args[1:], **kwargs) @@ -135,18 +136,17 @@ def wrapper(*args, **kwargs): def context_hypertile_vae(p): - global height, width, max_h, max_w, error_reported # pylint: disable=global-statement - error_reported = False - height=p.height - width=p.width - max_h = 0 - max_w = 0 from modules import shared if p.sd_model is None or not shared.opts.hypertile_vae_enabled: return nullcontext() if shared.opts.cross_attention_optimization == 'Sub-quadratic': shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') return nullcontext() + global height, width, max_h, max_w, error_reported # pylint: disable=global-statement + error_reported = False + error_reported = False + height, width = p.height, p.width + max_h, max_w = 0, 0 vae = getattr(p.sd_model, "vae", None) if shared.backend == shared.Backend.DIFFUSERS else getattr(p.sd_model, "first_stage_model", None) if height % 8 != 0 or width % 8 != 0: log.warning(f'Hypertile VAE disabled: width={width} height={height} are not divisible by 8') @@ -161,18 +161,16 @@ def context_hypertile_vae(p): def context_hypertile_unet(p): - global height, width, max_h, max_w, error_reported # pylint: disable=global-statement - error_reported = False - height=p.height - width=p.width - max_h = 0 - max_w = 0 from modules import shared if p.sd_model is None or not shared.opts.hypertile_unet_enabled: return nullcontext() if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental: shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') return nullcontext() + global height, width, max_h, max_w, error_reported # pylint: disable=global-statement + error_reported = False + height, width = p.height, p.width + max_h, max_w = 0, 0 unet = getattr(p.sd_model, "unet", None) if shared.backend == shared.Backend.DIFFUSERS else getattr(p.sd_model.model, "diffusion_model", None) if height % 8 != 0 or width % 8 != 0: log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8') @@ -192,6 +190,12 @@ def hypertile_set(p, hr=False): if not shared.opts.hypertile_unet_enabled: return error_reported = False - height=p.height if not hr else getattr(p, 'hr_upscale_to_y', p.height) - width=p.width if not hr else getattr(p, 'hr_upscale_to_x', p.width) + if hr: + x = getattr(p, 'hr_upscale_to_x', 0) + y = getattr(p, 'hr_upscale_to_y', 0) + width = y if y > 0 else p.width + height = x if x > 0 else p.height + else: + width=p.width + height=p.height reset_needed = True diff --git a/modules/sd_models.py b/modules/sd_models.py index 0956353ce..675358b9d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -168,6 +168,7 @@ def list_models(): shared.log.info(f'Available models: path="{shared.opts.ckpt_dir}" items={len(checkpoints_list)} time={time.time()-t0:.2f}') checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename)) + """ if len(checkpoints_list) == 0: if not shared.cmd_opts.no_download: key = input('Download the default model? 
(y/N) ') @@ -185,7 +186,7 @@ def list_models(): checkpoint_info = CheckpointInfo(filename) if checkpoint_info.name is not None: checkpoint_info.register() - + """ def update_model_hashes(): txt = [] @@ -474,10 +475,10 @@ def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info model.is_sdxl = False # a1111 compatibility item - model.is_sd2 = False # a1111 compatibility item - model.is_sd1 = True # a1111 compatibility item + model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item + model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item + model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) @@ -634,6 +635,10 @@ def detect_pipeline(f: str, op: str = 'model', warning=True): if shared.backend == shared.Backend.ORIGINAL: warn(f'Model detected as SD XL instruct pix2pix model, but attempting to load using backend=original: {op}={f} size={size} MB') guess = 'Stable Diffusion XL Instruct' + elif size > 3138 and size < 3142: #3140 + if shared.backend == shared.Backend.ORIGINAL: + warn(f'Model detected as Segmind Vega model, but attempting to load using backend=original: {op}={f} size={size} MB') + guess = 'Stable Diffusion XL' else: guess = 'Stable Diffusion' # guess by name @@ -676,7 +681,19 @@ def detect_pipeline(f: str, op: str = 'model', warning=True): return pipeline, guess -def set_diffuser_options(sd_model, vae, op: str): +def copy_diffuser_options(new_pipe, orig_pipe): + new_pipe.sd_checkpoint_info = orig_pipe.sd_checkpoint_info + new_pipe.sd_model_checkpoint = orig_pipe.sd_model_checkpoint + new_pipe.embedding_db = getattr(orig_pipe, 'embedding_db', None) + new_pipe.sd_model_hash = getattr(orig_pipe, 'sd_model_hash', None) + new_pipe.has_accelerate = getattr(orig_pipe, 'has_accelerate', False) + new_pipe.is_sdxl = getattr(orig_pipe, 'is_sdxl', False) # a1111 compatibility item + new_pipe.is_sd2 = getattr(orig_pipe, 'is_sd2', False) + new_pipe.is_sd1 = getattr(orig_pipe, 'is_sd1', True) + + + +def set_diffuser_options(sd_model, vae = None, op: str = 'model'): if sd_model is None: shared.log.warning(f'{op} is not loaded') return @@ -689,6 +706,18 @@ def set_diffuser_options(sd_model, vae, op: str): if hasattr(sd_model, "watermark"): sd_model.watermark = NoWatermark() sd_model.has_accelerate = False + if hasattr(sd_model, "vae"): + if vae is not None: + sd_model.vae = vae + if shared.opts.diffusers_vae_upcast != 'default': + if shared.opts.diffusers_vae_upcast == 'true': + sd_model.vae.config.force_upcast = True + else: + sd_model.vae.config.force_upcast = False + if shared.opts.no_half_vae: + devices.dtype_vae = torch.float32 + sd_model.vae.to(devices.dtype_vae) + shared.log.debug(f'Setting {op} VAE: name={sd_vae.loaded_vae_file} upcast={sd_model.vae.config.get("force_upcast", None)}') if hasattr(sd_model, "enable_model_cpu_offload"): if (shared.cmd_opts.medvram and devices.backend != "directml") or shared.opts.diffusers_model_cpu_offload: shared.log.debug(f'Setting {op}: enable model CPU offload') @@ -727,18 +756,8 @@ def set_diffuser_options(sd_model, vae, op: str): sd_model.enable_attention_slicing() else: 
sd_model.disable_attention_slicing() - if hasattr(sd_model, "vae"): - if vae is not None: - sd_model.vae = vae - if shared.opts.diffusers_vae_upcast != 'default': - if shared.opts.diffusers_vae_upcast == 'true': - sd_model.vae.config.force_upcast = True - else: - sd_model.vae.config.force_upcast = False - if shared.opts.no_half_vae: - devices.dtype_vae = torch.float32 - sd_model.vae.to(devices.dtype_vae) - shared.log.debug(f'Setting {op} VAE: name={sd_vae.loaded_vae_file} upcast={sd_model.vae.config.get("force_upcast", None)}') + if hasattr(sd_model, "vqvae"): + sd_model.vqvae.to(torch.float32) # vqvae is producing nans in fp16 if shared.opts.cross_attention_optimization == "xFormers" and hasattr(sd_model, 'enable_xformers_memory_efficient_attention'): sd_model.enable_xformers_memory_efficient_attention() @@ -806,7 +825,7 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No sd_model = None try: - if shared.cmd_opts.ckpt is not None and model_data.initial: # initial load + if shared.cmd_opts.ckpt is not None and os.path.isdir(shared.cmd_opts.ckpt) and model_data.initial: # initial load ckpt_basename = os.path.basename(shared.cmd_opts.ckpt) model_name = modelloader.find_diffuser(ckpt_basename) if model_name is not None: @@ -833,6 +852,7 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No if vae is not None: diffusers_load_config["vae"] = vae + shared.log.debug(f'Diffusers loading: path="{checkpoint_info.path}"') if os.path.isdir(checkpoint_info.path): err1 = None err2 = None @@ -953,6 +973,10 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No sd_model.sd_model_hash = checkpoint_info.calculate_shorthash() # pylint: disable=attribute-defined-outside-init sd_model.sd_checkpoint_info = checkpoint_info # pylint: disable=attribute-defined-outside-init sd_model.sd_model_checkpoint = checkpoint_info.filename # pylint: disable=attribute-defined-outside-init + sd_model.is_sdxl = False # a1111 compatibility item + sd_model.is_sd2 = hasattr(sd_model, 'cond_stage_model') and hasattr(sd_model.cond_stage_model, 'model') # a1111 compatibility item + sd_model.is_sd1 = not sd_model.is_sd2 # a1111 compatibility item + sd_model.logvar = sd_model.logvar.to(devices.device) if hasattr(sd_model, 'logvar') else None # fix for training shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 if hasattr(sd_model, "set_progress_bar_config"): sd_model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining}', ncols=80, colour='#327fba') @@ -1010,6 +1034,8 @@ def set_diffuser_pipe(pipe, new_pipe_type): sd_model_hash = getattr(pipe, "sd_model_hash", None) has_accelerate = getattr(pipe, "has_accelerate", None) embedding_db = getattr(pipe, "embedding_db", None) + image_encoder = getattr(pipe, "image_encoder", None) + feature_extractor = getattr(pipe, "feature_extractor", None) # TODO implement alternative diffusion pipelines """ @@ -1048,6 +1074,11 @@ def set_diffuser_pipe(pipe, new_pipe_type): new_pipe.sd_model_hash = sd_model_hash new_pipe.has_accelerate = has_accelerate new_pipe.embedding_db = embedding_db + new_pipe.image_encoder = image_encoder + new_pipe.feature_extractor = feature_extractor + new_pipe.is_sdxl = getattr(pipe, 'is_sdxl', False) # a1111 compatibility item + new_pipe.is_sd2 = getattr(pipe, 'is_sd2', False) + new_pipe.is_sd1 = getattr(pipe, 'is_sd1', True) shared.log.debug(f"Pipeline class change: 
original={pipe.__class__.__name__} target={new_pipe.__class__.__name__}") pipe = new_pipe return pipe @@ -1113,20 +1144,23 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, shared.log.debug(f'Model config loaded: {memory_stats()}') sd_model = None stdout = io.StringIO() - with contextlib.redirect_stdout(stdout): - """ - try: - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict - with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): + if os.environ.get('SD_LDM_DEBUG', None) is not None: + sd_model = instantiate_from_config(sd_config.model) + else: + with contextlib.redirect_stdout(stdout): + """ + try: + clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict + with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): + sd_model = instantiate_from_config(sd_config.model) + except Exception as e: + shared.log.error(f'LDM: instantiate from config: {e}') sd_model = instantiate_from_config(sd_config.model) - except Exception as e: - shared.log.error(f'LDM: instantiate from config: {e}') + """ sd_model = instantiate_from_config(sd_config.model) - """ - sd_model = instantiate_from_config(sd_config.model) - for line in stdout.getvalue().splitlines(): - if len(line) > 0: - shared.log.info(f'LDM: {line.strip()}') + for line in stdout.getvalue().splitlines(): + if len(line) > 0: + shared.log.info(f'LDM: {line.strip()}') shared.log.debug(f"Model created from config: {checkpoint_config}") sd_model.used_config = checkpoint_config sd_model.has_accelerate = False diff --git a/modules/sd_models_compile.py b/modules/sd_models_compile.py index 67fc32740..4840e4c49 100644 --- a/modules/sd_models_compile.py +++ b/modules/sd_models_compile.py @@ -9,6 +9,8 @@ class CompiledModelState: def __init__(self): self.first_pass = True + self.first_pass_refiner = True + self.first_pass_vae = True self.height = 512 self.width = 512 self.batch_size = 1 @@ -20,12 +22,15 @@ def __init__(self): self.partitioned_modules = {} -def optimize_ipex(sd_model): +def ipex_optimize(sd_model): try: t0 = time.time() import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import - sd_model.unet.training = False - sd_model.unet = ipex.optimize(sd_model.unet, dtype=devices.dtype_unet, inplace=True, weights_prepack=False) # pylint: disable=attribute-defined-outside-init + if hasattr(sd_model, 'unet'): + sd_model.unet.training = False + sd_model.unet = ipex.optimize(sd_model.unet, dtype=devices.dtype_unet, inplace=True, weights_prepack=False) # pylint: disable=attribute-defined-outside-init + else: + shared.log.warning('IPEX Optimize enabled but model has no Unet') if hasattr(sd_model, 'vae'): sd_model.vae.training = False sd_model.vae = ipex.optimize(sd_model.vae, dtype=devices.dtype_vae, inplace=True, weights_prepack=False) # pylint: disable=attribute-defined-outside-init @@ -33,9 +38,29 @@ def optimize_ipex(sd_model): sd_model.movq.training = False sd_model.movq = ipex.optimize(sd_model.movq, dtype=devices.dtype_vae, inplace=True, weights_prepack=False) # pylint: disable=attribute-defined-outside-init t1 = time.time() - shared.log.info(f"Model compile: mode=IPEX-optimize time={t1-t0:.2f}") + shared.log.info(f"IPEX Optimize: time={t1-t0:.2f}") + return sd_model + except Exception as e: + shared.log.warning(f"IPEX Optimize: error: {e}") + +def nncf_compress_weights(sd_model): + try: + t0 = time.time() + import nncf + if hasattr(sd_model, 
'unet'): + sd_model.unet = nncf.compress_weights(sd_model.unet) + else: + shared.log.warning('Compress Weights enabled but model has no Unet') + if shared.opts.nncf_compress_vae_weights: + if hasattr(sd_model, 'vae'): + sd_model.vae = nncf.compress_weights(sd_model.vae) + if hasattr(sd_model, 'movq'): + sd_model.movq = nncf.compress_weights(sd_model.movq) + t1 = time.time() + shared.log.info(f"Compress Weights: time={t1-t0:.2f}") + return sd_model except Exception as e: - shared.log.warning(f"Model compile: task=IPEX-optimize error: {e}") + shared.log.warning(f"Compress Weights: error: {e}") def optimize_openvino(): @@ -77,6 +102,7 @@ def compile_stablefast(sd_model): config.enable_cuda_graph = shared.opts.cuda_compile_fullgraph config.enable_jit_freeze = shared.opts.diffusers_eval config.memory_format = torch.channels_last if shared.opts.opt_channelslast else torch.contiguous_format + # config.trace_scheduler = False # config.enable_cnn_optimization # config.prefer_lowp_gemm try: @@ -97,8 +123,6 @@ def compile_torch(sd_model): import torch._dynamo # pylint: disable=unused-import,redefined-outer-name torch._dynamo.reset() # pylint: disable=protected-access shared.log.debug(f"Model compile available backends: {torch._dynamo.list_backends()}") # pylint: disable=protected-access - if shared.opts.ipex_optimize: - optimize_ipex(sd_model) if shared.opts.cuda_compile_backend == "openvino_fx": optimize_openvino() log_level = logging.WARNING if shared.opts.cuda_compile_verbose else logging.CRITICAL # pylint: disable=protected-access @@ -108,12 +132,17 @@ def compile_torch(sd_model): torch._dynamo.config.suppress_errors = shared.opts.cuda_compile_errors # pylint: disable=protected-access t0 = time.time() if shared.opts.cuda_compile: - sd_model.unet = torch.compile(sd_model.unet, mode=shared.opts.cuda_compile_mode, backend=shared.opts.cuda_compile_backend, fullgraph=shared.opts.cuda_compile_fullgraph) + if shared.opts.cuda_compile and (not hasattr(sd_model, 'unet') or not hasattr(sd_model.unet, 'config')): + shared.log.warning('Model compile enabled but model has no Unet') + else: + sd_model.unet = torch.compile(sd_model.unet, mode=shared.opts.cuda_compile_mode, backend=shared.opts.cuda_compile_backend, fullgraph=shared.opts.cuda_compile_fullgraph) if shared.opts.cuda_compile_vae: - if hasattr(sd_model, 'vae'): + if hasattr(sd_model, 'vae') and hasattr(sd_model.vae, 'decode'): sd_model.vae.decode = torch.compile(sd_model.vae.decode, mode=shared.opts.cuda_compile_mode, backend=shared.opts.cuda_compile_backend, fullgraph=shared.opts.cuda_compile_fullgraph) - if hasattr(sd_model, 'movq'): + elif hasattr(sd_model, 'movq') and hasattr(sd_model.movq, 'decode'): sd_model.movq.decode = torch.compile(sd_model.movq.decode, mode=shared.opts.cuda_compile_mode, backend=shared.opts.cuda_compile_backend, fullgraph=shared.opts.cuda_compile_fullgraph) + else: + shared.log.warning('Model compile enabled but model has no VAE') setup_logging() # compile messes with logging so reset is needed if shared.opts.cuda_compile_precompile: sd_model("dummy prompt") @@ -125,16 +154,16 @@ def compile_torch(sd_model): def compile_diffusers(sd_model): + if shared.opts.ipex_optimize: + sd_model = ipex_optimize(sd_model) + if shared.opts.nncf_compress_weights: + sd_model = nncf_compress_weights(sd_model) if not (shared.opts.cuda_compile or shared.opts.cuda_compile_vae or shared.opts.cuda_compile_upscaler): return sd_model - if not hasattr(sd_model, 'unet') or not hasattr(sd_model.unet, 'config'): - shared.log.warning('Model compile 
enabled but model has no Unet') - return sd_model if shared.opts.cuda_compile_backend == 'none': shared.log.warning('Model compile enabled but no backend specified') return sd_model - size = 8*getattr(sd_model.unet.config, 'sample_size', 0) - shared.log.info(f"Model compile: pipeline={sd_model.__class__.__name__} shape={size} mode={shared.opts.cuda_compile_mode} backend={shared.opts.cuda_compile_backend} fullgraph={shared.opts.cuda_compile_fullgraph} unet={shared.opts.cuda_compile} vae={shared.opts.cuda_compile_vae} upscaler={shared.opts.cuda_compile_upscaler}") + shared.log.info(f"Model compile: pipeline={sd_model.__class__.__name__} mode={shared.opts.cuda_compile_mode} backend={shared.opts.cuda_compile_backend} fullgraph={shared.opts.cuda_compile_fullgraph} unet={shared.opts.cuda_compile} vae={shared.opts.cuda_compile_vae} upscaler={shared.opts.cuda_compile_upscaler}") if shared.opts.cuda_compile_backend == 'stable-fast': sd_model = compile_stablefast(sd_model) else: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 62f0c8a73..5b710f449 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,7 +1,10 @@ +import os from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_diffusers, shared from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # pylint: disable=unused-import +debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: SAMPLER') all_samplers = [] all_samplers = [] all_samplers_map = {} @@ -9,6 +12,7 @@ samplers_for_img2img = all_samplers samplers_map = {} + def list_samplers(backend_name = shared.backend): global all_samplers # pylint: disable=global-statement global all_samplers_map # pylint: disable=global-statement diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index dd24d42ae..88f511a5d 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -38,6 +38,8 @@ def single_sample_to_image(sample, approximation=None): warn_once('Unknown decode type, please reset preview method') approximation = 0 + if len(sample.shape) > 4: # likely unknown video latent (e.g. 
svd) + return Image.new(mode="RGB", size=(512, 512)) if len(sample.shape) == 4 and sample.shape[0]: # likely animatediff latent sample = sample.permute(1, 0, 2, 3)[0] if approximation == 0: # Simple diff --git a/modules/sd_samplers_diffusers.py b/modules/sd_samplers_diffusers.py index 9cd3ee806..f61d7ddde 100644 --- a/modules/sd_samplers_diffusers.py +++ b/modules/sd_samplers_diffusers.py @@ -1,6 +1,12 @@ +import os +import inspect from modules import shared from modules import sd_samplers_common + +debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: SAMPLER') + try: from diffusers import ( DDIMScheduler, @@ -26,22 +32,22 @@ config = { # beta_start, beta_end are typically per-scheduler, but we don't want them as they should be taken from the model itself as those are values model was trained on # prediction_type is ideally set in model as well, but it maybe needed that we do auto-detect of model type in the future - 'All': { 'num_train_timesteps': 1000, 'beta_start': 0.0001, 'beta_end': 0.02, 'beta_schedule': 'linear', 'prediction_type': 'epsilon' }, + 'All': { 'num_train_timesteps': 500, 'beta_start': 0.0001, 'beta_end': 0.02, 'beta_schedule': 'linear', 'prediction_type': 'epsilon' }, 'DDIM': { 'clip_sample': True, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'linspace', 'rescale_betas_zero_snr': False }, 'DDPM': { 'variance_type': "fixed_small", 'clip_sample': True, 'thresholding': False, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'linspace'}, 'DEIS': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "deis", 'solver_type': "logrho", 'lower_order_final': True }, 'DPM++ 1S': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False }, 'DPM++ 2M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False }, 'DPM SDE': { 'use_karras_sigmas': False }, - 'Euler a': { }, - 'Euler': { 'interpolation_type': "linear", 'use_karras_sigmas': False }, + 'Euler a': { 'rescale_betas_zero_snr': False }, + 'Euler': { 'interpolation_type': "linear", 'use_karras_sigmas': False, 'rescale_betas_zero_snr': False }, 'Heun': { 'use_karras_sigmas': False }, 'KDPM2': { 'steps_offset': 0 }, 'KDPM2 a': { 'steps_offset': 0 }, 'LMSD': { 'use_karras_sigmas': False, 'timestep_spacing': 'linspace', 'steps_offset': 0 }, 'PNDM': { 'skip_prk_steps': False, 'set_alpha_to_one': False, 'steps_offset': 0 }, 'UniPC': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'predict_x0': 'bh2', 'lower_order_final': True }, - 'LCM': { 'num_train_timesteps': 1000, 'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': "scaled_linear", 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False }, + 'LCM': { 'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': "scaled_linear", 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False }, } samplers_data_diffusers = [ @@ -108,9 +114,21 @@ def __init__(self, name, constructor, model, **kwargs): self.config['beta_start'] = shared.opts.schedulers_beta_start if 'beta_end' in self.config and shared.opts.schedulers_beta_end > 0: self.config['beta_end'] = shared.opts.schedulers_beta_end + if 'rescale_betas_zero_snr' in self.config: + 
self.config['rescale_betas_zero_snr'] = shared.opts.schedulers_rescale_betas + if 'num_train_timesteps' in self.config: + self.config['num_train_timesteps'] = shared.opts.schedulers_timesteps_range if name == 'DPM++ 2M': self.config['algorithm_type'] = shared.opts.schedulers_dpm_solver if name == 'DEIS': self.config['algorithm_type'] = 'deis' + # validate all config params + signature = inspect.signature(constructor, follow_wrapped=True) + possible = signature.parameters.keys() + debug(f'Sampler: sampler="{name}" config={self.config} signature={possible}') + for key in self.config.copy().keys(): + if key not in possible: + shared.log.warning(f'Sampler: sampler="{name}" config={self.config} invalid={key}') + del self.config[key] self.sampler = constructor(**self.config) self.sampler.name = name diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9e28e4c73..f3707e3d8 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -147,21 +147,25 @@ def load_vae(model, vae_file=None, vae_source="unknown-source"): global loaded_vae_file # pylint: disable=global-statement cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 if vae_file: - if cache_enabled and vae_file in checkpoints_loaded: - # use vae checkpoint cache - shared.log.info(f"Loading VAE: model={get_filename(vae_file)} source={vae_source} cached=True") - store_base_vae(model) - _load_vae_dict(model, checkpoints_loaded[vae_file]) - else: - if not os.path.isfile(vae_file): - shared.log.error(f"VAE not found: model={vae_file} source={vae_source}") - return - store_base_vae(model) - vae_dict_1 = load_vae_dict(vae_file) - _load_vae_dict(model, vae_dict_1) - if cache_enabled: - # cache newly loaded vae - checkpoints_loaded[vae_file] = vae_dict_1.copy() + try: + if cache_enabled and vae_file in checkpoints_loaded: + # use vae checkpoint cache + shared.log.info(f"Loading VAE: model={get_filename(vae_file)} source={vae_source} cached=True") + store_base_vae(model) + _load_vae_dict(model, checkpoints_loaded[vae_file]) + else: + if not os.path.isfile(vae_file): + shared.log.error(f"VAE not found: model={vae_file} source={vae_source}") + return + store_base_vae(model) + vae_dict_1 = load_vae_dict(vae_file) + _load_vae_dict(model, vae_dict_1) + if cache_enabled: + # cache newly loaded vae + checkpoints_loaded[vae_file] = vae_dict_1.copy() + except Exception as e: + shared.log.error(f"Loading VAE failed: model={vae_file} source={vae_source} {e}") + restore_base_vae(model) # clean up cache if limit is reached if cache_enabled: while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model @@ -270,8 +274,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): if hasattr(shared.sd_model, "vae") and hasattr(shared.sd_model, "sd_checkpoint_info"): vae = load_vae_diffusers(shared.sd_model.sd_checkpoint_info.filename, vae_file, vae_source) if vae is not None: - if vae is not None: - sd_model.vae = vae + sd_models.set_diffuser_options(sd_model, vae=vae, op='vae') if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram and not getattr(sd_model, 'has_accelerate', False): sd_model.to(devices.device) diff --git a/modules/shared.py b/modules/shared.py index 7c2bea684..30c06a1fb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -63,6 +63,7 @@ "outdir_save", "outdir_init_images" } +resize_modes = ["None", "Fixed", "Crop", "Fill", "Latent"] compatibility_opts = ['clip_skip', 'uni_pc_lower_order_final', 'uni_pc_order'] console = Console(log_time=True, log_time_format='%H:%M:%S-%f') @@ 
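Editor's note on the sampler-config validation added in `modules/sd_samplers_diffusers.py` above: the new loop filters the assembled config against the scheduler constructor's signature before instantiating it, so stale or backend-specific keys are dropped instead of raising a TypeError. A minimal standalone sketch of the same pattern; the scheduler class and config values here are illustrative, not taken from the patch:

```python
import inspect
from diffusers import EulerDiscreteScheduler  # any diffusers scheduler class works the same way

# an over-complete config: the last key is not a constructor argument
config = {"num_train_timesteps": 1000, "prediction_type": "epsilon", "not_a_real_arg": 123}

# inspect.signature on the class resolves to __init__ (without self)
accepted = inspect.signature(EulerDiscreteScheduler, follow_wrapped=True).parameters.keys()
for key in list(config.keys()):
    if key not in accepted:
        print(f"dropping unsupported scheduler argument: {key}")
        del config[key]

scheduler = EulerDiscreteScheduler(**config)
print(scheduler.config.num_train_timesteps)  # -> 1000
```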
-151,6 +152,11 @@ def refresh_vaes(): modules.sd_vae.refresh_vae_list() +def refresh_upscalers(): + import modules.modelloader # pylint: disable=W0621 + modules.modelloader.load_upscalers() + + def list_samplers(): import modules.sd_samplers # pylint: disable=W0621 modules.sd_samplers.set_samplers() @@ -166,17 +172,14 @@ def temp_disable_extensions(): for ext in disable_safe: if ext not in opts.disabled_extensions: disabled.append(ext) - log.info(f'Safe mode disabling extensions: {disabled}') if backend == Backend.DIFFUSERS: for ext in disable_diffusers: if ext not in opts.disabled_extensions: disabled.append(ext) - log.info(f'Disabling uncompatible extensions: backend={backend} {disabled}') if backend == Backend.ORIGINAL: for ext in disable_original: if ext not in opts.disabled_extensions: disabled.append(ext) - log.info(f'Disabling uncompatible extensions: backend={backend} {disabled}') cmd_opts.controlnet_loglevel = 'WARNING' return disabled @@ -188,6 +191,7 @@ def readfile(filename, silent=False, lock=False): try: if not os.path.exists(filename): return {} + t0 = time.time() if lock: lock_file = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log) locked = lock_file.acquire_read_lock(blocking=True, timeout=3) @@ -195,8 +199,9 @@ def readfile(filename, silent=False, lock=False): data = json.load(file) if type(data) is str: data = json.loads(data) + t1 = time.time() if not silent: - log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)}') + log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f}') except Exception as e: if not silent: log.error(f'Reading failed: {filename} {e}') @@ -209,15 +214,18 @@ def readfile(filename, silent=False, lock=False): return data -def writefile(data, filename, mode='w', silent=False): +def writefile(data, filename, mode='w', silent=False, atomic=False): lock = None locked = False + import tempfile + def default(obj): log.error(f"Saving: {filename} not a valid object: {obj}") return str(obj) try: + t0 = time.time() # skipkeys=True, ensure_ascii=True, check_circular=True, allow_nan=True if type(data) == dict: output = json.dumps(data, indent=2, default=default) @@ -233,12 +241,21 @@ def default(obj): raise ValueError('not a valid object') lock = fasteners.InterProcessReaderWriterLock(f"{filename}.lock", logger=log) locked = lock.acquire_write_lock(blocking=True, timeout=3) - with open(filename, mode, encoding="utf8") as file: - file.write(output) + if atomic: + with tempfile.NamedTemporaryFile(mode=mode, encoding="utf8", delete=False, dir=os.path.dirname(filename)) as f: + f.write(output) + f.flush() + os.fsync(f.fileno()) + os.replace(f.name, filename) + else: + with open(filename, mode=mode, encoding="utf8") as file: + file.write(output) + t1 = time.time() if not silent: - log.debug(f'Save: file="{filename}" json={len(data)} bytes={len(output)}') + log.debug(f'Save: file="{filename}" json={len(data)} bytes={len(output)} time={t1-t0:.3f}') except Exception as e: log.error(f'Saving failed: {filename} {e}') + errors.display(e, 'Saving failed') finally: if lock is not None: lock.release_read_lock() @@ -261,7 +278,7 @@ def default(obj): options_templates.update(options_section(('sd', "Execution & Models"), { - "sd_backend": OptionInfo("diffusers" if cmd_opts.use_openvino else "original", "Execution backend", gr.Radio, {"choices": ["original", "diffusers"] }), + "sd_backend": OptionInfo("original" if not cmd_opts.use_openvino else "diffusers", "Execution backend", 
gr.Radio, {"choices": ["original", "diffusers"] }), "sd_checkpoint_autoload": OptionInfo(True, "Model autoload on server start"), "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_tiles()}, refresh=refresh_checkpoints), @@ -281,9 +298,9 @@ def default(obj): "math_sep": OptionInfo("
<h2>Execution precision</h2>
", "", gr.HTML), "precision": OptionInfo("Autocast", "Precision type", gr.Radio, {"choices": ["Autocast", "Full"]}), "cuda_dtype": OptionInfo("FP32" if sys.platform == "darwin" or cmd_opts.use_openvino else "BF16" if devices.backend == "ipex" else "FP16", "Device precision type", gr.Radio, {"choices": ["FP32", "FP16", "BF16"]}), - "no_half": OptionInfo(True if cmd_opts.use_openvino else False, "Use full precision for model (--no-half)", None, None, None), - "no_half_vae": OptionInfo(True if cmd_opts.use_openvino else False, "Use full precision for VAE (--no-half-vae)"), - "upcast_sampling": OptionInfo(True if sys.platform == "darwin" else False, "Enable upcast sampling"), + "no_half": OptionInfo(False if not cmd_opts.use_openvino else True, "Use full precision for model (--no-half)", None, None, None), + "no_half_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Use full precision for VAE (--no-half-vae)"), + "upcast_sampling": OptionInfo(False if sys.platform != "darwin" else True, "Enable upcast sampling"), "upcast_attn": OptionInfo(False, "Enable upcast cross attention layer"), "cuda_cast_unet": OptionInfo(False, "Use fixed UNet precision"), "disable_nan_check": OptionInfo(True, "Disable NaN check in produced images/latent spaces", gr.Checkbox, {"visible": False}), @@ -303,25 +320,27 @@ def default(obj): "torch_gc_threshold": OptionInfo(80 if devices.backend == "ipex" else 90, "VRAM usage threshold before running Torch GC to clear up VRAM", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}), "cuda_compile_sep": OptionInfo("
<h2>Model Compile</h2>
", "", gr.HTML), - "cuda_compile": OptionInfo(True if cmd_opts.use_openvino else False, "Compile UNet"), - "cuda_compile_vae": OptionInfo(True if cmd_opts.use_openvino else False, "Compile VAE"), - "cuda_compile_upscaler": OptionInfo(True if cmd_opts.use_openvino else False, "Compile upscaler"), - "cuda_compile_backend": OptionInfo("openvino_fx" if cmd_opts.use_openvino else "none", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'ipex', 'openvino_fx', 'stable-fast']}), + "cuda_compile": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile UNet"), + "cuda_compile_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile VAE"), + "cuda_compile_upscaler": OptionInfo(False if not cmd_opts.use_openvino else True, "Compile upscaler"), + "cuda_compile_backend": OptionInfo("none" if not cmd_opts.use_openvino else "openvino_fx", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'ipex', 'openvino_fx', 'stable-fast']}), "cuda_compile_mode": OptionInfo("default", "Model compile mode", gr.Radio, {"choices": ['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']}), "cuda_compile_fullgraph": OptionInfo(False, "Model compile fullgraph"), - "cuda_compile_precompile": OptionInfo(False if cmd_opts.use_openvino else True, "Model compile precompile"), + "cuda_compile_precompile": OptionInfo(False, "Model compile precompile"), "cuda_compile_verbose": OptionInfo(False, "Model compile verbose mode"), "cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"), "ipex_sep": OptionInfo("
<h2>IPEX, DirectML and OpenVINO</h2>
", "", gr.HTML), - "ipex_optimize": OptionInfo(True if devices.backend == "ipex" else False, "Enable IPEX Optimize for Intel GPUs"), - "ipex_optimize_upscaler": OptionInfo(True if devices.backend == "ipex" else False, "Enable IPEX Optimize for Intel GPUs with Upscalers"), + "ipex_optimize": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs"), + "ipex_optimize_upscaler": OptionInfo(False if not devices.backend == "ipex" else True, "Enable IPEX Optimize for Intel GPUs with Upscalers"), "directml_memory_provider": OptionInfo(default_memory_provider, 'DirectML memory stats provider', gr.Radio, {"choices": memory_providers}), "directml_catch_nan": OptionInfo(False, "DirectML retry specific operation when NaN is produced if possible. (makes generation slower)"), "openvino_disable_model_caching": OptionInfo(False, "OpenVINO disable model caching"), "openvino_hetero_gpu": OptionInfo(False, "OpenVINO use Hetero Device for single inference with multiple devices"), "openvino_remove_cpu_from_hetero": OptionInfo(False, "OpenVINO remove CPU from Hetero Device"), "openvino_remove_igpu_from_hetero": OptionInfo(False, "OpenVINO remove iGPU from Hetero Device"), + "nncf_compress_weights": OptionInfo(False, "Compress Model weights to 8 bit with NNCF"), + "nncf_compress_vae_weights": OptionInfo(False, "Compress VAE weights to 8 bit with NNCF"), })) options_templates.update(options_section(('advanced', "Inference Settings"), { @@ -355,12 +374,12 @@ def default(obj): "diffusers_move_unet": OptionInfo(True, "Move base model to CPU when using VAE"), "diffusers_move_refiner": OptionInfo(True, "Move refiner model to CPU when not in use"), "diffusers_extract_ema": OptionInfo(True, "Use model EMA weights when possible"), - "diffusers_generator_device": OptionInfo("default", "Generator device", gr.Radio, {"choices": ["default", "cpu"]}), + "diffusers_generator_device": OptionInfo("GPU", "Generator device", gr.Radio, {"choices": ["GPU", "CPU", "Unset"]}), "diffusers_model_cpu_offload": OptionInfo(False, "Enable model CPU offload (--medvram)"), "diffusers_seq_cpu_offload": OptionInfo(False, "Enable sequential CPU offload (--lowvram)"), "diffusers_vae_upcast": OptionInfo("default", "VAE upcasting", gr.Radio, {"choices": ['default', 'true', 'false']}), "diffusers_vae_slicing": OptionInfo(True, "Enable VAE slicing"), - "diffusers_vae_tiling": OptionInfo(False if cmd_opts.use_openvino else True, "Enable VAE tiling"), + "diffusers_vae_tiling": OptionInfo(True if not cmd_opts.use_openvino else False, "Enable VAE tiling"), "diffusers_attention_slicing": OptionInfo(False, "Enable attention slicing"), "diffusers_model_load_variant": OptionInfo("default", "Diffusers model loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}), "diffusers_vae_load_variant": OptionInfo("default", "Diffusers VAE loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}), @@ -384,6 +403,7 @@ def default(obj): "styles_dir": OptionInfo(os.path.join(paths.data_path, 'styles.csv'), "File or Folder with user-defined styles", folder=True), "embeddings_dir": OptionInfo(os.path.join(paths.models_path, 'embeddings'), "Folder with textual inversion embeddings", folder=True), "hypernetwork_dir": OptionInfo(os.path.join(paths.models_path, 'hypernetworks'), "Folder with Hypernetwork models", folder=True), + "control_dir": OptionInfo(os.path.join(paths.models_path, 'control'), "Folder with Control models", folder=True), "codeformer_models_path": OptionInfo(os.path.join(paths.models_path, 
'Codeformer'), "Folder with codeformer models", folder=True), "gfpgan_models_path": OptionInfo(os.path.join(paths.models_path, 'GFPGAN'), "Folder with GFPGAN models", folder=True), "esrgan_models_path": OptionInfo(os.path.join(paths.models_path, 'ESRGAN'), "Folder with ESRGAN models", folder=True), @@ -395,6 +415,7 @@ def default(obj): "clip_models_path": OptionInfo(os.path.join(paths.models_path, 'CLIP'), "Folder with CLIP models", folder=True), "other_paths_sep_options": OptionInfo("
<h2>Other paths</h2>
", "", gr.HTML), + "openvino_cache_path": OptionInfo('cache', "Directory for OpenVINO cache", folder=True), "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True), "clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"), })) @@ -469,8 +490,8 @@ def default(obj): "gallery_height": OptionInfo("", "Gallery height", gr.Textbox), "compact_view": OptionInfo(False, "Compact view"), "return_grid": OptionInfo(True, "Show grid in results"), - "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results"), - "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results"), + "return_mask": OptionInfo(False, "Inpainting include greyscale mask in results"), + "return_mask_composite": OptionInfo(False, "Inpainting include masked composite in results"), "disable_weights_auto_swap": OptionInfo(True, "Do not change selected model when reading generation parameters"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), @@ -487,7 +508,7 @@ def default(obj): "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid", gr.Checkbox, {"visible": False}), "notification_audio_enable": OptionInfo(False, "Play a sound when images are finished generating"), "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True), - "show_progress_every_n_steps": OptionInfo(1, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), + "show_progress_every_n_steps": OptionInfo(1, "Live preview display period", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approximate", "Live preview method", gr.Radio, {"choices": ["Simple", "Approximate", "TAESD", "Full VAE"]}), "live_preview_content": OptionInfo("Combined", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"], "visible": False}), "live_preview_refresh_period": OptionInfo(500, "Progress update period", gr.Slider, {"minimum": 0, "maximum": 5000, "step": 25}), @@ -516,6 +537,8 @@ def default(obj): "schedulers_beta_schedule": OptionInfo("default", "Beta schedule", gr.Radio, {"choices": ['default', 'linear', 'scaled_linear', 'squaredcos_cap_v2']}), 'schedulers_beta_start': OptionInfo(0, "Beta start", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.00001}), 'schedulers_beta_end': OptionInfo(0, "Beta end", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.00001}), + 'schedulers_timesteps_range': OptionInfo(1000, "Timesteps range", gr.Slider, {"minimum": 250, "maximum": 4000, "step": 1}), + "schedulers_rescale_betas": OptionInfo(False, "Rescale betas with zero terminal SNR", gr.Checkbox), # managed from ui.py for backend original k-diffusion "schedulers_sep_kdiffusers": OptionInfo("
<h2>K-Diffusion specific config</h2>
", "", gr.HTML), @@ -540,9 +563,11 @@ def default(obj): options_templates.update(options_section(('postprocessing', "Postprocessing"), { 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable additional postprocessing operations", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + "postprocessing_sep_img2img": OptionInfo("
<h2>Img2Img & Inpainting</h2>
", "", gr.HTML), - "img2img_color_correction": OptionInfo(False, "Apply color correction to match original colors"), - "img2img_fix_steps": OptionInfo(False, "For image processing do exact number of steps as specified"), + "img2img_color_correction": OptionInfo(False, "Apply color correction"), + # "img2img_apply_overlay": OptionInfo(False, "Apply result as overlay"), + "img2img_fix_steps": OptionInfo(False, "For image processing do exact number of steps as specified", gr.Checkbox, { "visible": False }), "img2img_background_color": OptionInfo("#ffffff", "Image transparent color fill", ui_components.FormColorPicker, {}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for image processing", gr.Slider, {"minimum": 0.1, "maximum": 1.5, "step": 0.01}), @@ -556,7 +581,7 @@ def default(obj): "postprocessing_sep_upscalers": OptionInfo("
<h2>Upscaling</h2>
", "", gr.HTML), "upscaler_unload": OptionInfo(False, "Unload upscaler after processing"), # 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": False}), - "upscaler_for_img2img": OptionInfo("None", "Default upscaler for image resize operations", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), + "upscaler_for_img2img": OptionInfo("None", "Default upscaler for image resize operations", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers], "visible": False}, refresh=refresh_upscalers), "upscaler_tile_size": OptionInfo(192, "Upscaler tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "upscaler_tile_overlap": OptionInfo(8, "Upscaler tile overlap", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), })) @@ -601,18 +626,17 @@ def default(obj): "extra_networks_card_size": OptionInfo(160, "UI card size (px)", gr.Slider, {"minimum": 20, "maximum": 2000, "step": 1}), "extra_networks_card_square": OptionInfo(True, "UI disable variable aspect ratio"), "extra_networks_card_fit": OptionInfo("cover", "UI image contain method", gr.Radio, {"choices": ["contain", "cover", "fill"], "visible": False}), - "extra_networks_sep2": OptionInfo("
<h2>Extra networks general</h2>
", "", gr.HTML), - "extra_network_skip_indexing": OptionInfo(False, "Do not automatically build extra network pages", gr.Checkbox), + "extra_network_skip_indexing": OptionInfo(False, "Build info on first access", gr.Checkbox), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "extra_networks_sep3": OptionInfo("
<h2>Extra networks settings</h2>
", "", gr.HTML), "extra_networks_styles": OptionInfo(True, "Show built-in styles"), "lora_preferred_name": OptionInfo("filename", "LoRA preffered name", gr.Radio, {"choices": ["filename", "alias"]}), "lora_add_hashes_to_infotext": OptionInfo(True, "LoRA add hash info"), + "lora_force_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use alternative loading method"), + "lora_fuse_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use merge when using alternative method"), "lora_in_memory_limit": OptionInfo(0, "LoRA memory cache", gr.Slider, {"minimum": 0, "maximum": 24, "step": 1}), "lora_functional": OptionInfo(False, "Use Kohya method for handling multiple LoRA", gr.Checkbox, { "visible": False }), - "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, { "choices": ["None"], "visible": False }), })) @@ -923,8 +947,9 @@ def req(url_addr, headers = None, **kwargs): class Shared(sys.modules[__name__].__class__): # this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than at program startup. @property def sd_model(self): - # log.debug(f'Access shared.sd_model: {sys._getframe().f_back.f_code.co_name}') # pylint: disable=protected-access import modules.sd_models # pylint: disable=W0621 + if modules.sd_models.model_data.sd_model is None: + log.debug(f'Model requested: fn={sys._getframe().f_back.f_code.co_name}') # pylint: disable=protected-access return modules.sd_models.model_data.get_sd_model() @sd_model.setter diff --git a/modules/shared_state.py b/modules/shared_state.py index 81ebbf120..c597417f1 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -21,6 +21,7 @@ class State: current_image_sampling_step = 0 id_live_preview = 0 textinfo = None + api = False time_start = None need_restart = False server_start = time.time() @@ -58,7 +59,7 @@ def dict(self): } return obj - def begin(self, title=""): + def begin(self, title="", api=None): import modules.devices self.total_jobs += 1 self.current_image = None @@ -74,12 +75,13 @@ def begin(self, title=""): self.sampling_step = 0 self.skipped = False self.textinfo = None + self.api = api if api is not None else self.api self.time_start = time.time() if self.debug_output: log.debug(f'State begin: {self.job}') modules.devices.torch_gc() - def end(self): + def end(self, api=None): import modules.devices if self.time_start is None: # someone called end before being log.debug(f'Access state.end: {sys._getframe().f_back.f_code.co_name}') # pylint: disable=protected-access @@ -92,20 +94,21 @@ def end(self): self.paused = False self.interrupted = False self.skipped = False + self.api = api if api is not None else self.api modules.devices.torch_gc() def set_current_image(self): from modules.shared import opts, cmd_opts """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" - if cmd_opts.lowvram: + if cmd_opts.lowvram or self.api: return if abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps > 0: self.do_set_current_image() def do_set_current_image(self): - from modules.shared import opts - if self.current_latent is None: + if self.current_latent is None or self.api: return + from modules.shared import opts import modules.sd_samplers # pylint: disable=W0621 try: image = modules.sd_samplers.samples_to_image_grid(self.current_latent) 
if opts.show_progress_grid else modules.sd_samplers.sample_to_image(self.current_latent) diff --git a/modules/theme.py b/modules/theme.py index 83a14c902..35f98fe55 100644 --- a/modules/theme.py +++ b/modules/theme.py @@ -1,6 +1,5 @@ import os import json -import urllib.request import gradio as gr import modules.shared # from modules.shared import log, opts, req, writefile @@ -51,18 +50,21 @@ def reload_gradio_theme(theme_name=None): if not theme_name: theme_name = modules.shared.opts.gradio_theme default_font_params = {} + """ res = 0 try: + import urllib.request request = urllib.request.Request("https://fonts.googleapis.com/css2?family=IBM+Plex+Mono", method="HEAD") res = urllib.request.urlopen(request, timeout=3.0).status # pylint: disable=consider-using-with except Exception: res = 0 if res != 200: modules.shared.log.info('No internet access detected, using default fonts') - default_font_params = { - 'font':['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'], - 'font_mono':['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'] - } + """ + default_font_params = { + 'font':['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'], + 'font_mono':['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'] + } is_builtin = theme_name in list_builtin_themes() modules.shared.log.info(f'Load UI theme: name="{theme_name}" style={modules.shared.opts.theme_style} base={"sdnext.css" if is_builtin else "base.css"}') if is_builtin: diff --git a/modules/txt2img.py b/modules/txt2img.py index 55751bb2d..deb87caba 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,12 +1,31 @@ +import os import modules.scripts from modules import sd_samplers, shared, processing from modules.generation_parameters_copypaste import create_override_settings_dict from modules.ui import plaintext_to_html -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, latent_index: int, full_quality: bool, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, diffusers_guidance_rescale: float, clip_skip: int, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_force: bool, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, refiner_steps: int, refiner_start: int, refiner_prompt: str, refiner_negative: str, hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, override_settings_texts, *args): # pylint: disable=unused-argument +debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: PROCESS') - shared.log.debug(f'txt2img: 
id_task={id_task}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|steps={steps}|sampler_index={sampler_index}|latent_index={latent_index}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|clip_skip={clip_skip}|seed={seed}|subseed={subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}||height={height}|width={width}|enable_hr={enable_hr}|denoising_strength={denoising_strength}|hr_scale={hr_scale}|hr_upscaler={hr_upscaler}|hr_force={hr_force}|hr_second_pass_steps={hr_second_pass_steps}|hr_resize_x={hr_resize_x}|hr_resize_y={hr_resize_y}|image_cfg_scale={image_cfg_scale}|diffusers_guidance_rescale={diffusers_guidance_rescale}|refiner_steps={refiner_steps}|refiner_start={refiner_start}|refiner_prompt={refiner_prompt}|refiner_negative={refiner_negative}|override_settings_texts={override_settings_texts}') + +def txt2img(id_task, + prompt, negative_prompt, prompt_styles, + steps, sampler_index, latent_index, + full_quality, restore_faces, tiling, + n_iter, batch_size, + cfg_scale, image_cfg_scale, diffusers_guidance_rescale, + clip_skip, + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, + height, width, + enable_hr, denoising_strength, + hr_scale, hr_upscaler, hr_force, hr_second_pass_steps, hr_resize_x, hr_resize_y, + refiner_steps, refiner_start, refiner_prompt, refiner_negative, + hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, + override_settings_texts, + *args): + + debug(f'txt2img: id_task={id_task}|prompt={prompt}|negative={negative_prompt}|styles={prompt_styles}|steps={steps}|sampler_index={sampler_index}|latent_index={latent_index}|full_quality={full_quality}|restore_faces={restore_faces}|tiling={tiling}|batch_count={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|clip_skip={clip_skip}|seed={seed}|subseed={subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|height={height}|width={width}|enable_hr={enable_hr}|denoising_strength={denoising_strength}|hr_scale={hr_scale}|hr_upscaler={hr_upscaler}|hr_force={hr_force}|hr_second_pass_steps={hr_second_pass_steps}|hr_resize_x={hr_resize_x}|hr_resize_y={hr_resize_y}|image_cfg_scale={image_cfg_scale}|diffusers_guidance_rescale={diffusers_guidance_rescale}|refiner_steps={refiner_steps}|refiner_start={refiner_start}|refiner_prompt={refiner_prompt}|refiner_negative={refiner_negative}|override_settings={override_settings_texts}') if shared.sd_model is None: shared.log.warning('Model not loaded') diff --git a/modules/ui.py b/modules/ui.py index 6b076469b..6bba53a3e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -9,41 +9,37 @@ import numpy as np from PIL import Image from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call - -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, ui_loadsave, ui_train, ui_models, ui_interrogate +from modules import timer, shared, theme, sd_models, script_callbacks, modelloader, prompt_parser, ui_common, ui_loadsave, ui_symbols, generation_parameters_copypaste from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path -from modules.shared import opts, cmd_opts from modules.dml 
import directml_override_opts -from modules import prompt_parser -from modules import timer -import modules.ui_symbols as symbols -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.hypernetworks.ui import modules.scripts -import modules.shared -import modules.errors -import modules.styles -import modules.extras -import modules.theme import modules.textual_inversion.ui -import modules.sd_samplers +import modules.hypernetworks.ui +import modules.errors modules.errors.install() mimetypes.init() mimetypes.add_type('application/javascript', '.js') -log = modules.shared.log +log = shared.log +opts = shared.opts +cmd_opts = shared.cmd_opts ui_system_tabs = None -switch_values_symbol = symbols.switch -detect_image_size_symbol = symbols.detect -paste_symbol = symbols.paste -clear_prompt_symbol = symbols.clear -restore_progress_symbol = symbols.apply -folder_symbol = symbols.folder -extra_networks_symbol = symbols.networks -apply_style_symbol = symbols.apply -save_style_symbol = symbols.save +switch_values_symbol = ui_symbols.switch +detect_image_size_symbol = ui_symbols.detect +paste_symbol = ui_symbols.paste +clear_prompt_symbol = ui_symbols.clear +restore_progress_symbol = ui_symbols.apply +folder_symbol = ui_symbols.folder +extra_networks_symbol = ui_symbols.networks +apply_style_symbol = ui_symbols.apply +save_style_symbol = ui_symbols.save +txt2img_paste_fields = [] +img2img_paste_fields = [] +txt2img_args = [] +img2img_args = [] +paste_function = None if not cmd_opts.share and not cmd_opts.listen: @@ -56,18 +52,15 @@ def gr_show(visible=True): return {"visible": visible, "__type__": "update"} -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None -paste_function = None - - def create_output_panel(tabname, outdir): # pylint: disable=unused-argument # outdir is used by extensions a, b, c, _d, e = ui_common.create_output_panel(tabname) return a, b, c, e + def plaintext_to_html(text): # may be referenced by extensions return ui_common.plaintext_to_html(text) + def infotext_to_html(text): # may be referenced by extensions return ui_common.infotext_to_html(text) @@ -75,16 +68,17 @@ def infotext_to_html(text): # may be referenced by extensions def send_gradio_gallery_to_image(x): if len(x) == 0: return None - return parameters_copypaste.image_from_url_text(x[0]) + return generation_parameters_copypaste.image_from_url_text(x[0]) def add_style(name: str, prompt: str, negative_prompt: str): + from modules import styles if name is None: return [gr_show() for x in range(4)] - style = modules.styles.Style(name, prompt, negative_prompt) - modules.shared.prompt_styles.styles[style.name] = style - modules.shared.prompt_styles.save_styles(modules.shared.opts.styles_dir) - return [gr.Dropdown.update(visible=True, choices=list(modules.shared.prompt_styles.styles)) for _ in range(2)] + style = styles.Style(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + shared.prompt_styles.save_styles(shared.opts.styles_dir) + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] def calc_resolution_hires(width, height, hr_scale, hr_resize_x, hr_resize_y, hr_upscaler): @@ -107,8 +101,8 @@ def resize_from_to_html(width, height, scale_by): def apply_styles(prompt, prompt_neg, styles): - prompt = modules.shared.prompt_styles.apply_styles_to_prompt(prompt, styles) - prompt_neg = 
modules.shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] @@ -128,7 +122,7 @@ def process_interrogate(interrogation_function, mode, ii_input_files, ii_input_d if not os.path.isdir(ii_input_dir): log.error(f"Interrogate: Input directory not found: {ii_input_dir}") return [gr.update(), None] - images = modules.shared.listfiles(ii_input_dir) + images = shared.listfiles(ii_input_dir) if ii_output_dir != "": os.makedirs(ii_output_dir, exist_ok=True) else: @@ -145,47 +139,139 @@ def interrogate(image): if image is None: log.error("Interrogate: no image selected") return gr.update() - prompt = modules.shared.interrogator.interrogate(image.convert("RGB")) + prompt = shared.interrogator.interrogate(image.convert("RGB")) return gr.update() if prompt is None else prompt def interrogate_deepbooru(image): + from modules import deepbooru prompt = deepbooru.model.tag(image) return gr.update() if prompt is None else prompt -def create_seed_inputs(tab): +def create_batch_inputs(tab): + with gr.Accordion(open=False, label="Batch", elem_id=f"{tab}_batch", elem_classes=["small-accordion"]): + with FormRow(elem_id=f"{tab}_row_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id=f"{tab}_batch_count") + batch_size = gr.Slider(minimum=1, maximum=32, step=1, label='Batch size', value=1, elem_id=f"{tab}_batch_size") + batch_switch_btn = ToolButton(value=ui_symbols.switch, elem_id=f"{tab}_batch_switch_btn", label="Switch dims") + batch_switch_btn.click(lambda w, h: (h, w), inputs=[batch_count, batch_size], outputs=[batch_count, batch_size], show_progress=False) + return batch_count, batch_size + + +def create_seed_inputs(tab, reuse_visible=True): with gr.Accordion(open=False, label="Seed", elem_id=f"{tab}_seed_group", elem_classes=["small-accordion"]): with FormRow(elem_id=f"{tab}_seed_row", variant="compact"): seed = gr.Number(label='Initial seed', value=-1, elem_id=f"{tab}_seed", container=True) - random_seed = ToolButton(symbols.random, elem_id=f"{tab}_random_seed", label='Random seed') - reuse_seed = ToolButton(symbols.reuse, elem_id=f"{tab}_reuse_seed", label='Reuse seed') - with FormRow(visible=True, elem_id=f"{tab}_subseed_row", variant="compact"): + random_seed = ToolButton(ui_symbols.random, elem_id=f"{tab}_random_seed", label='Random seed') + reuse_seed = ToolButton(ui_symbols.reuse, elem_id=f"{tab}_reuse_seed", label='Reuse seed', visible=reuse_visible) + with FormRow(elem_id=f"{tab}_subseed_row", variant="compact", visible=shared.backend==shared.Backend.ORIGINAL): subseed = gr.Number(label='Variation', value=-1, elem_id=f"{tab}_subseed", container=True) - random_subseed = ToolButton(symbols.random, elem_id=f"{tab}_random_subseed") - reuse_subseed = ToolButton(symbols.reuse, elem_id=f"{tab}_reuse_subseed") + random_subseed = ToolButton(ui_symbols.random, elem_id=f"{tab}_random_subseed") + reuse_subseed = ToolButton(ui_symbols.reuse, elem_id=f"{tab}_reuse_subseed", visible=reuse_visible) subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{tab}_subseed_strength") with FormRow(visible=False): seed_resize_from_w = gr.Slider(minimum=0, maximum=4096, step=8, label="Resize seed from width", value=0, 
elem_id=f"{tab}_seed_resize_from_w") seed_resize_from_h = gr.Slider(minimum=0, maximum=4096, step=8, label="Resize seed from height", value=0, elem_id=f"{tab}_seed_resize_from_h") random_seed.click(fn=lambda: [-1, -1], show_progress=False, inputs=[], outputs=[seed, subseed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w + + +def create_advanced_inputs(tab): + with gr.Accordion(open=False, label="Advanced", elem_id=f"{tab}_advanced", elem_classes=["small-accordion"]): + with gr.Group(): + with FormRow(): + cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label='CFG scale', value=4.0, elem_id=f"{tab}_cfg_scale") + clip_skip = gr.Slider(label='CLIP skip', value=1, minimum=1, maximum=14, step=1, elem_id=f"{tab}_clip_skip", interactive=True) + with FormRow(): + image_cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label='Secondary CFG scale', value=4.0, elem_id=f"{tab}_image_cfg_scale") + diffusers_guidance_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Guidance rescale', value=0.7, elem_id=f"{tab}_image_cfg_rescale", visible=shared.backend == shared.Backend.DIFFUSERS) + with gr.Group(): + with FormRow(): + full_quality = gr.Checkbox(label='Full quality', value=True, elem_id=f"{tab}_full_quality") + restore_faces = gr.Checkbox(label='Face restore', value=False, visible=len(shared.face_restorers) > 1, elem_id=f"{tab}_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id=f"{tab}_tiling", visible=shared.backend == shared.Backend.ORIGINAL) + with gr.Group(visible=shared.backend == shared.Backend.DIFFUSERS): + with FormRow(): + hdr_clamp = gr.Checkbox(label='HDR clamp', value=False, elem_id=f"{tab}_hdr_clamp") + hdr_boundary = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=4.0, label='Range', elem_id=f"{tab}_hdr_boundary") + hdr_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label='Threshold', elem_id=f"{tab}_hdr_threshold") + with FormRow(): + hdr_center = gr.Checkbox(label='HDR center', value=False, elem_id=f"{tab}_hdr_center") + hdr_channel_shift = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label='Channel shift', elem_id=f"{tab}_hdr_channel_shift") + hdr_full_shift = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1, label='Full shift', elem_id=f"{tab}_hdr_full_shift") + with FormRow(): + hdr_maximize = gr.Checkbox(label='HDR maximize', value=False, elem_id=f"{tab}_hdr_maximize") + hdr_max_center = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=0.6, label='Center', elem_id=f"{tab}_hdr_max_center") + hdr_max_boundry = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label='Range', elem_id=f"{tab}_hdr_max_boundry") + return cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, full_quality, restore_faces, tiling, hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry + + +def create_resize_inputs(tab, images, time_selector=False, scale_visible=True, mode=None): + dummy_component = gr.Number(visible=False, value=0) + with gr.Accordion(open=False, label="Resize", elem_classes=["small-accordion"], elem_id=f"{tab}_resize_group"): + with gr.Row(): + if mode is not None: + resize_mode = gr.Radio(label="Resize mode", 
elem_id=f"{tab}_resize_mode", choices=shared.resize_modes, type="index", value=mode, visible=False) + else: + resize_mode = gr.Radio(label="Resize mode", elem_id=f"{tab}_resize_mode", choices=shared.resize_modes, type="index", value='None') + resize_time = gr.Radio(label="Resize order", elem_id=f"{tab}_resize_order", choices=['Before', 'After'], value="Before", visible=time_selector) + with gr.Row(): + resize_name = gr.Dropdown(label="Resize method", elem_id=f"{tab}_resize_name", choices=[x.name for x in shared.sd_upscalers], value=opts.upscaler_for_img2img) + create_refresh_button(resize_name, modelloader.load_upscalers, lambda: {"choices": modelloader.load_upscalers()}, 'refresh_upscalers') + + with FormRow(visible=True) as _resize_group: + with gr.Column(elem_id=f"{tab}_column_size"): + selected_scale_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated + with gr.Tabs(): + with gr.Tab(label="Resize to") as tab_scale_to: + with FormRow(): + with gr.Column(elem_id=f"{tab}_column_size"): + with FormRow(): + width = gr.Slider(minimum=64, maximum=8192, step=8, label="Width", value=512, elem_id=f"{tab}_width") + height = gr.Slider(minimum=64, maximum=8192, step=8, label="Height", value=512, elem_id=f"{tab}_height") + res_switch_btn = ToolButton(value=ui_symbols.switch, elem_id=f"{tab}_res_switch_btn") + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) + detect_image_size_btn = ToolButton(value=ui_symbols.detect, elem_id=f"{tab}_detect_image_size_btn") + detect_image_size_btn.click(fn=lambda w, h, _: (w or gr.update(), h or gr.update()), _js="currentImg2imgSourceResolution", inputs=[dummy_component, dummy_component, dummy_component], outputs=[width, height], show_progress=False) + + with gr.Tab(label="Resize by") as tab_scale_by: + scale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Scale", value=1.0, elem_id=f"{tab}_scale") + if scale_visible: + with FormRow(): + scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id=f"{tab}_scale_resolution_preview") + gr.Slider(label="Unused", elem_id=f"{tab}_unused_scale_by_slider") + button_update_resize_to = gr.Button(visible=False, elem_id=f"{tab}_update_resize_to") + + on_change_args = dict(fn=resize_from_to_html, _js="currentImg2imgSourceResolution", inputs=[dummy_component, dummy_component, scale_by], outputs=scale_by_html, show_progress=False) + scale_by.release(**on_change_args) + button_update_resize_to.click(**on_change_args) + + for component in images: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + + tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) + tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) + # resize_mode.change(fn=lambda x: gr.update(visible=x != 0), inputs=[resize_mode], outputs=[_resize_group]) + return resize_mode, resize_name, width, height, scale_by, selected_scale_tab, resize_time def connect_clear_prompt(button): # pylint: disable=unused-argument pass -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength was 0, i.e. 
no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): + def copy_seed(gen_info_string: str, index: int): res = -1 try: gen_info = json.loads(gen_info_string) + log.debug(f'Reuse: info={gen_info}') index -= gen_info.get('index_of_first_image', 0) + index = int(index) if is_subseed and gen_info.get('subseed_strength', 0) > 0: all_subseeds = gen_info.get('all_subseeds', [-1]) @@ -198,32 +284,33 @@ def copy_seed(gen_info_string: str, index): log.error(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] + dummy_component = gr.Number(visible=False, value=0) reuse_seed.click(fn=copy_seed, _js="(x, y) => [x, selected_gallery_index()]", show_progress=False, inputs=[generation_info, dummy_component], outputs=[seed, dummy_component]) def update_token_counter(text, steps): + from modules import extra_networks, sd_hijack try: text, _ = extra_networks.parse_prompt(text) _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console prompt_schedules = [[[steps, text]]] + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) prompts = [prompt_text for step, prompt_text in flat_prompts] - if modules.shared.backend == modules.shared.Backend.ORIGINAL: + if shared.backend == shared.Backend.ORIGINAL: token_count, max_length = max([sd_hijack.model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - elif modules.shared.backend == modules.shared.Backend.DIFFUSERS: - if modules.shared.sd_model is not None and hasattr(modules.shared.sd_model, 'tokenizer'): - tokenizer = modules.shared.sd_model.tokenizer + elif shared.backend == shared.Backend.DIFFUSERS: + if shared.sd_model is not None and hasattr(shared.sd_model, 'tokenizer'): + tokenizer = shared.sd_model.tokenizer if tokenizer is None: token_count = 0 max_length = 75 else: has_bos_token = tokenizer.bos_token_id is not None has_eos_token = tokenizer.eos_token_id is not None - ids = [modules.shared.sd_model.tokenizer(prompt) for prompt in prompts] + ids = [shared.sd_model.tokenizer(prompt) for prompt in prompts] if len(ids) > 0 and hasattr(ids[0], 'input_ids'): ids = [x.input_ids for x in ids] token_count = max([len(x) for x in ids]) - int(has_bos_token) - int(has_eos_token) @@ -234,8 +321,9 @@ def update_token_counter(text, steps): return f"{token_count}/{max_length}" -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" +def create_toprow(is_img2img: bool = False, id_part: str = None): + if id_part is None: + id_part = "img2img" if is_img2img else "txt2img" with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): with gr.Row(): @@ -250,18 +338,18 @@ def create_toprow(is_img2img): button_deepbooru = None if is_img2img: with gr.Column(scale=1, elem_classes="interrogate-col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id=f"{id_part}_interrogate") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id=f"{id_part}_deepbooru") with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): 
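Editor's note on the reworked `update_token_counter` above: with the diffusers backend it counts prompt tokens with the loaded pipeline's own tokenizer and subtracts BOS/EOS when the tokenizer defines them. A standalone sketch of that counting logic using a stock CLIP tokenizer; the checkpoint name is illustrative (the app uses whatever tokenizer the current pipeline exposes):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # illustrative checkpoint
prompt = "a photo of an astronaut riding a horse on the moon"
ids = tokenizer(prompt).input_ids

# subtract begin/end tokens if the tokenizer defines them, mirroring update_token_counter
token_count = len(ids) - int(tokenizer.bos_token_id is not None) - int(tokenizer.eos_token_id is not None)
max_length = tokenizer.model_max_length - 2  # 77 positions minus BOS/EOS -> 75 usable tokens
print(f"{token_count}/{max_length}")
```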
with gr.Row(elem_id=f"{id_part}_generate_box"): submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') with gr.Row(elem_id=f"{id_part}_generate_line2"): interrupt = gr.Button('Stop', elem_id=f"{id_part}_interrupt") - interrupt.click(fn=lambda: modules.shared.state.interrupt(), _js="requestInterrupt", inputs=[], outputs=[]) + interrupt.click(fn=lambda: shared.state.interrupt(), _js="requestInterrupt", inputs=[], outputs=[]) skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - skip.click(fn=lambda: modules.shared.state.skip(), inputs=[], outputs=[]) + skip.click(fn=lambda: shared.state.skip(), inputs=[], outputs=[]) pause = gr.Button('Pause', elem_id=f"{id_part}_pause") - pause.click(fn=lambda: modules.shared.state.pause(), _js='checkPaused', inputs=[], outputs=[]) + pause.click(fn=lambda: shared.state.pause(), _js='checkPaused', inputs=[], outputs=[]) with gr.Row(elem_id=f"{id_part}_tools"): button_paste = gr.Button(value='Restore', variant='secondary', elem_id=f"{id_part}_paste") # symbols.paste button_clear = gr.Button(value='Clear', variant='secondary', elem_id=f"{id_part}_clear_prompt_btn") # symbols.clear @@ -273,14 +361,15 @@ def create_toprow(is_img2img): negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") with gr.Row(elem_id=f"{id_part}_styles_row"): - prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[style.name for style in modules.shared.prompt_styles.styles.values()], value=[], multiselect=True) - prompt_styles_btn_refresh = ToolButton(symbols.refresh, elem_id=f"{id_part}_styles_refresh", visible=True) - prompt_styles_btn_refresh.click(fn=lambda: gr.update(choices=[style.name for style in modules.shared.prompt_styles.styles.values()]), inputs=[], outputs=[prompt_styles]) - prompt_styles_btn_select = gr.Button('Select', elem_id=f"{id_part}_styles_select", visible=False) - prompt_styles_btn_select.click(_js="applyStyles", fn=parse_style, inputs=[prompt_styles], outputs=[prompt_styles]) - prompt_styles_btn_apply = ToolButton(symbols.apply, elem_id=f"{id_part}_extra_apply", visible=False) - prompt_styles_btn_apply.click(fn=apply_styles, inputs=[prompt, negative_prompt, prompt_styles], outputs=[prompt, negative_prompt, prompt_styles]) - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, button_paste, button_extra, token_counter, token_button, negative_token_counter, negative_token_button + styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[style.name for style in shared.prompt_styles.styles.values()], value=[], multiselect=True) + _styles_btn_refresh = create_refresh_button(styles, shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"{id_part}_styles_refresh") + # styles_btn_refresh = ToolButton(symbols.refresh, elem_id=f"{id_part}_styles_refresh", visible=True) + # styles_btn_refresh.click(fn=lambda: gr.update(choices=[style.name for style in shared.prompt_styles.styles.values()]), inputs=[], outputs=[styles]) + styles_btn_select = gr.Button('Select', elem_id=f"{id_part}_styles_select", visible=False) + styles_btn_select.click(_js="applyStyles", fn=parse_style, inputs=[styles], outputs=[styles]) + styles_btn_apply = ToolButton(ui_symbols.apply, elem_id=f"{id_part}_extra_apply", visible=False) + styles_btn_apply.click(fn=apply_styles, inputs=[prompt, negative_prompt, 
styles], outputs=[prompt, negative_prompt, styles]) + return prompt, styles, negative_prompt, submit, button_interrogate, button_deepbooru, button_paste, button_extra, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument @@ -290,7 +379,7 @@ def setup_progressbar(*args, **kwargs): # pylint: disable=unused-argument def apply_setting(key, value): if value is None: return gr.update() - if modules.shared.cmd_opts.freeze: + if shared.cmd_opts.freeze: return gr.update() # dont allow model to be swapped when model hash exists in prompt if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: @@ -309,7 +398,7 @@ def apply_setting(key, value): opts.data[key] = valtype(value) if valtype != type(None) else value if oldval != value and opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() - opts.save(modules.shared.config_filename) + opts.save(shared.config_filename) return getattr(opts, key) @@ -322,18 +411,19 @@ def set_sampler_original_options(sampler_options, sampler_algo): opts.data['schedulers_brownian_noise'] = 'brownian noise' in sampler_options opts.data['schedulers_discard_penultimate'] = 'discard penultimate sigma' in sampler_options opts.data['schedulers_sigma'] = sampler_algo - opts.save(modules.shared.config_filename, silent=True) + opts.save(shared.config_filename, silent=True) def set_sampler_diffuser_options(sampler_options): opts.data['schedulers_use_karras'] = 'karras' in sampler_options opts.data['schedulers_use_thresholding'] = 'dynamic thresholding' in sampler_options opts.data['schedulers_use_loworder'] = 'low order' in sampler_options - opts.save(modules.shared.config_filename, silent=True) + opts.data['schedulers_rescale_betas'] = 'rescale beta' in sampler_options + opts.save(shared.config_filename, silent=True) with FormRow(elem_classes=['flex-break']): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value='Default', type="index") steps = gr.Slider(minimum=1, maximum=99, step=1, label="Sampling steps", elem_id=f"{tabname}_steps", value=20) - if modules.shared.backend == modules.shared.Backend.ORIGINAL: + if shared.backend == shared.Backend.ORIGINAL: with FormRow(elem_classes=['flex-break']): choices = ['brownian noise', 'discard penultimate sigma'] values = [] @@ -347,16 +437,56 @@ def set_sampler_diffuser_options(sampler_options): sampler_algo.change(fn=set_sampler_original_options, inputs=[sampler_options, sampler_algo], outputs=[]) else: with FormRow(elem_classes=['flex-break']): - choices = ['karras', 'dynamic thresholding', 'low order'] + choices = ['karras', 'dynamic threshold', 'low order', 'rescale beta'] values = [] values += ['karras'] if opts.data.get('schedulers_use_karras', True) else [] - values += ['dynamic thresholding'] if opts.data.get('schedulers_use_thresholding', False) else [] + values += ['dynamic threshold'] if opts.data.get('schedulers_use_thresholding', False) else [] values += ['low order'] if opts.data.get('schedulers_use_loworder', True) else [] + values += ['rescale beta'] if opts.data.get('schedulers_rescale_betas', False) else [] sampler_options = gr.CheckboxGroup(label='Sampler options', choices=choices, value=values, type='value') sampler_options.change(fn=set_sampler_diffuser_options, inputs=[sampler_options], outputs=[]) return steps, sampler_index +def create_sampler_inputs(tab): + from modules import sd_samplers + with 
gr.Accordion(open=False, label="Sampler", elem_id=f"{tab}_sampler", elem_classes=["small-accordion"]): + with FormRow(elem_id=f"{tab}_row_sampler"): + sd_samplers.set_samplers() + steps, sampler_index = create_sampler_and_steps_selection(sd_samplers.samplers, tab) + return steps, sampler_index + + +def create_hires_inputs(tab): + with gr.Accordion(open=False, label="Second pass", elem_id=f"{tab}_second_pass", elem_classes=["small-accordion"]): + with FormGroup(): + with FormRow(elem_id=f"{tab}_hires_row1"): + enable_hr = gr.Checkbox(label='Enable second pass', value=False, elem_id=f"{tab}_enable_hr") + with FormRow(elem_id=f"{tab}_hires_row2"): + latent_index = gr.Dropdown(label='Secondary sampler', elem_id=f"{tab}_sampling_alt", choices=[x.name for x in modules.sd_samplers.samplers], value='Default', type="index") + denoising_strength = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, label='Denoising strength', value=0.5, elem_id=f"{tab}_denoising_strength") + with FormRow(elem_id=f"{tab}_hires_finalres", variant="compact"): + hr_final_resolution = FormHTML(value="", elem_id=f"{tab}_hr_finalres", label="Upscaled resolution", interactive=False) + with FormRow(elem_id=f"{tab}_hires_fix_row1", variant="compact"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id=f"{tab}_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_force = gr.Checkbox(label='Force Hires', value=False, elem_id=f"{tab}_hr_force") + with FormRow(elem_id=f"{tab}_hires_fix_row2", variant="compact"): + hr_second_pass_steps = gr.Slider(minimum=0, maximum=99, step=1, label='Hires steps', elem_id=f"{tab}_steps_alt", value=20) + hr_scale = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Upscale by", value=2.0, elem_id=f"{tab}_hr_scale") + with FormRow(elem_id=f"{tab}_hires_fix_row3", variant="compact"): + hr_resize_x = gr.Slider(minimum=0, maximum=4096, step=8, label="Resize width to", value=0, elem_id=f"{tab}_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=4096, step=8, label="Resize height to", value=0, elem_id=f"{tab}_hr_resize_y") + with FormGroup(visible=shared.backend == shared.Backend.DIFFUSERS): + with FormRow(elem_id=f"{tab}_refiner_row1", variant="compact"): + refiner_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Refiner start', value=0.8, elem_id=f"{tab}_refiner_start") + refiner_steps = gr.Slider(minimum=0, maximum=99, step=1, label="Refiner steps", elem_id=f"{tab}_refiner_steps", value=5) + with FormRow(elem_id=f"{tab}_refiner_row3", variant="compact"): + refiner_prompt = gr.Textbox(value='', label='Secondary prompt', elem_id=f"{tab}_refiner_prompt") + with FormRow(elem_id="txt2img_refiner_row4", variant="compact"): + refiner_negative = gr.Textbox(value='', label='Secondary negative prompt', elem_id=f"{tab}_refiner_neg_prompt") + return enable_hr, latent_index, denoising_strength, hr_final_resolution, hr_upscaler, hr_force, hr_second_pass_steps, hr_scale, hr_resize_x, hr_resize_y, refiner_steps, refiner_start, refiner_prompt, refiner_negative + + def get_value_for_setting(key): value = getattr(opts, key) info = opts.data_labels[key] @@ -366,31 +496,34 @@ def get_value_for_setting(key): def ordered_ui_categories(): - return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 'override_settings', 'scripts'] # TODO: a1111 compatibility item, not implemented + return ['dimensions', 'sampler', 'seed', 'denoising', 'cfg', 'checkboxes', 'accordions', 
'override_settings', 'scripts'] # a1111 compatibility item, not implemented -def create_override_settings_dropdown(tabname, row): # pylint: disable=unused-argument - dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) - dropdown.change(fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), inputs=[dropdown], outputs=[dropdown]) - return dropdown +def create_override_inputs(tab): # pylint: disable=unused-argument + with FormRow(elem_id=f"{tab}_override_settings_row"): + override_settings = gr.Dropdown([], value=None, label="Override settings", visible=False, elem_id=f"{tab}_override_settings", multiselect=True) + override_settings.change(fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), inputs=[override_settings], outputs=[override_settings]) + return override_settings def create_ui(startup_timer = None): if startup_timer is None: timer.startup = timer.Timer() reload_javascript() - parameters_copypaste.reset() + generation_parameters_copypaste.reset() import modules.txt2img # pylint: disable=redefined-outer-name modules.scripts.scripts_current = modules.scripts.scripts_txt2img modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _interrogate, _deepbooru, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) - dummy_component = gr.Label(visible=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, txt2img_submit, _interrogate, _deepbooru, txt2img_paste, txt2img_extra_networks_button, txt2img_token_counter, txt2img_token_button, txt2img_negative_token_counter, txt2img_negative_token_button = create_toprow(is_img2img=False, id_part="txt2img") + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + txt_prompt_img.change(fn=modules.images.image_data, inputs=[txt_prompt_img], outputs=[txt2img_prompt, txt_prompt_img]) + with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks_ui: from modules import ui_extra_networks - extra_networks_ui = ui_extra_networks.create_ui(extra_networks_ui, extra_networks_button, 'txt2img', skip_indexing=opts.extra_network_skip_indexing) + extra_networks_ui = ui_extra_networks.create_ui(extra_networks_ui, txt2img_extra_networks_button, 'txt2img', skip_indexing=opts.extra_network_skip_indexing) timer.startup.record('ui-extra-networks') with gr.Row(elem_id="txt2img_interface", equal_height=False): @@ -399,80 +532,19 @@ def create_ui(startup_timer = None): with FormRow(): width = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=512, elem_id="txt2img_width") height = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=512, elem_id="txt2img_height") - res_switch_btn = ToolButton(value=symbols.switch, elem_id="txt2img_res_switch_btn", label="Switch dims") + res_switch_btn = ToolButton(value=ui_symbols.switch, elem_id="txt2img_res_switch_btn", label="Switch dims") + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) with FormGroup(elem_classes="settings-accordion"): - with gr.Accordion(open=False, label="Sampler", elem_id="txt2img_sampler", elem_classes=["small-accordion"]): - with FormRow(elem_id="txt2img_row_sampler"): - 
modules.sd_samplers.set_samplers() - steps, sampler_index = create_sampler_and_steps_selection(modules.sd_samplers.samplers, "txt2img") - - with gr.Accordion(open=False, label="Batch", elem_id="txt2img_batch", elem_classes=["small-accordion"]): - with FormRow(elem_id="txt2img_row_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=32, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - batch_switch_btn = ToolButton(value=symbols.switch, elem_id="txt2img_batch_switch_btn", label="Switch dims") + steps, sampler_index = create_sampler_inputs('txt2img') + batch_count, batch_size = create_batch_inputs('txt2img') seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs('txt2img') + cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, full_quality, restore_faces, tiling, hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry = create_advanced_inputs('txt2img') + enable_hr, latent_index, denoising_strength, hr_final_resolution, hr_upscaler, hr_force, hr_second_pass_steps, hr_scale, hr_resize_x, hr_resize_y, refiner_steps, refiner_start, refiner_prompt, refiner_negative = create_hires_inputs('txt2img') + override_settings = create_override_inputs('txt2img') - with gr.Accordion(open=False, label="Advanced", elem_id="txt2img_advanced", elem_classes=["small-accordion"]): - with gr.Group(): - with FormRow(): - cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label='CFG scale', value=6.0, elem_id="txt2img_cfg_scale") - clip_skip = gr.Slider(label='CLIP skip', value=1, minimum=1, maximum=14, step=1, elem_id='txt2img_clip_skip', interactive=True) - with FormRow(): - image_cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label='Secondary CFG scale', value=6.0, elem_id="txt2img_image_cfg_scale") - diffusers_guidance_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Guidance rescale', value=0.7, elem_id="txt2img_image_cfg_rescale") - with gr.Group(): - with FormRow(): - full_quality = gr.Checkbox(label='Full quality', value=True, elem_id="txt2img_full_quality") - restore_faces = gr.Checkbox(label='Face restore', value=False, visible=len(modules.shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - with gr.Group(): - with FormRow(): - hdr_clamp = gr.Checkbox(label='HDR clamp', value=False, elem_id="txt2img_hdr_clamp") - hdr_boundary = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=4.0, label='Range', elem_id="txt2img_hdr_boundary") - hdr_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label='Threshold', elem_id="txt2img_hdr_threshold") - with FormRow(): - hdr_center = gr.Checkbox(label='HDR center', value=False, elem_id="txt2img_hdr_center") - hdr_channel_shift = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label='Channel shift', elem_id="txt2img_hdr_channel_shift") - hdr_full_shift = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1, label='Full shift', elem_id="txt2img_hdr_full_shift") - with FormRow(): - hdr_maximize = gr.Checkbox(label='HDR maximize', value=False, elem_id="txt2img_hdr_maximize") - hdr_max_center = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=0.6, label='Center', elem_id="txt2img_hdr_max_center") - hdr_max_boundry = gr.Slider(minimum=0.5, 
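The calls above are the core of this refactor: each create_*_inputs helper builds one small accordion and returns its components, so the txt2img, img2img and new control tabs can share the same building blocks. A reduced sketch of that shape (the real helpers return more controls and use different defaults), assuming Gradio 3.x:

    import gradio as gr

    def create_batch_inputs(tab: str):
        # elem_ids stay unique because every id is prefixed with the owning tab name
        with gr.Accordion(open=False, label='Batch', elem_id=f'{tab}_batch', elem_classes=['small-accordion']):
            with gr.Row(elem_id=f'{tab}_row_batch'):
                batch_count = gr.Slider(minimum=1, maximum=32, step=1, label='Batch count', value=1, elem_id=f'{tab}_batch_count')
                batch_size = gr.Slider(minimum=1, maximum=32, step=1, label='Batch size', value=1, elem_id=f'{tab}_batch_size')
        return batch_count, batch_size

    with gr.Blocks() as demo:
        txt2img_batch_count, txt2img_batch_size = create_batch_inputs('txt2img')
        img2img_batch_count, img2img_batch_size = create_batch_inputs('img2img')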
maximum=2.0, step=0.1, value=1.0, label='Range', elem_id="txt2img_hdr_max_boundry") - - with gr.Accordion(open=False, label="Second pass", elem_id="txt2img_second_pass", elem_classes=["small-accordion"]): - with FormGroup(): - with FormRow(elem_id="sampler_selection_txt2img_alt_row1"): - enable_hr = gr.Checkbox(label='Enable second pass', value=False, elem_id="txt2img_enable_hr") - with FormRow(elem_id="sampler_selection_txt2img_alt_row1"): - latent_index = gr.Dropdown(label='Secondary sampler', elem_id="txt2img_sampling_alt", choices=[x.name for x in modules.sd_samplers.samplers], value='Default', type="index") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Denoising strength', value=0.5, elem_id="txt2img_denoising_strength") - with FormRow(elem_id="txt2img_hires_finalres", variant="compact"): - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*modules.shared.latent_upscale_modes, *[x.name for x in modules.shared.sd_upscalers]], value=modules.shared.latent_upscale_default_mode) - hr_force = gr.Checkbox(label='Force Hires', value=False, elem_id="txt2img_hr_force") - with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): - hr_second_pass_steps = gr.Slider(minimum=0, maximum=99, step=1, label='Hires steps', elem_id="txt2img_steps_alt", value=20) - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact"): - hr_resize_x = gr.Slider(minimum=0, maximum=4096, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=4096, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormGroup(visible=modules.shared.backend == modules.shared.Backend.DIFFUSERS): - with FormRow(elem_id="txt2img_refiner_row1", variant="compact"): - refiner_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Refiner start', value=0.8, elem_id="txt2img_refiner_start") - refiner_steps = gr.Slider(minimum=0, maximum=99, step=1, label="Refiner steps", elem_id="txt2img_refiner_steps", value=5) - with FormRow(elem_id="txt2img_refiner_row3", variant="compact"): - refiner_prompt = gr.Textbox(value='', label='Secondary Prompt') - with FormRow(elem_id="txt2img_refiner_row4", variant="compact"): - refiner_negative = gr.Textbox(value='', label='Secondary negative prompt') - - with FormRow(elem_id="txt2img_override_settings_row") as row: - override_settings = create_override_settings_dropdown('txt2img', row) - - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + txt2img_script_inputs = modules.scripts.scripts_txt2img.setup_ui() hr_resolution_preview_inputs = [width, height, hr_scale, hr_resize_x, hr_resize_y, hr_upscaler] for preview_input in hr_resolution_preview_inputs: @@ -484,47 +556,44 @@ def create_ui(startup_timer = None): show_progress=False, ) - txt2img_gallery, generation_info, html_info, _html_info_formatted, html_log = ui_common.create_output_panel("txt2img") - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( + txt2img_gallery, txt2img_generation_info, txt2img_html_info, 
_txt2img_html_info_formatted, txt2img_html_log = ui_common.create_output_panel("txt2img") + connect_reuse_seed(seed, reuse_seed, txt2img_generation_info, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, txt2img_generation_info, is_subseed=True) + + global txt2img_args # pylint: disable=global-statement + dummy_component = gr.Textbox(visible=False, value='dummy') + txt2img_args = [ + dummy_component, + txt2img_prompt, txt2img_negative_prompt, txt2img_prompt_styles, + steps, sampler_index, latent_index, + full_quality, restore_faces, tiling, + batch_count, batch_size, + cfg_scale, image_cfg_scale, diffusers_guidance_rescale, + clip_skip, + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, + height, width, + enable_hr, denoising_strength, + hr_scale, hr_upscaler, hr_force, hr_second_pass_steps, hr_resize_x, hr_resize_y, + refiner_steps, refiner_start, refiner_prompt, refiner_negative, + hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, + override_settings, + ] + txt2img_dict = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit_txt2img", - inputs=[ - dummy_component, - txt2img_prompt, txt2img_negative_prompt, - txt2img_prompt_styles, - steps, - sampler_index, latent_index, - full_quality, restore_faces, tiling, - batch_count, batch_size, - cfg_scale, image_cfg_scale, - diffusers_guidance_rescale, - clip_skip, - seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, - height, width, - enable_hr, denoising_strength, - hr_scale, hr_upscaler, hr_force, hr_second_pass_steps, hr_resize_x, hr_resize_y, - refiner_steps, refiner_start, refiner_prompt, refiner_negative, - hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, - override_settings, - ] + custom_inputs, + inputs=txt2img_args + txt2img_script_inputs, outputs=[ txt2img_gallery, - generation_info, - html_info, - html_log, + txt2img_generation_info, + txt2img_html_info, + txt2img_html_log, ], show_progress=False, ) - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) - batch_switch_btn.click(lambda w, h: (h, w), inputs=[batch_count, batch_size], outputs=[batch_count, batch_size], show_progress=False) - txt_prompt_img.change(fn=modules.images.image_data, inputs=[txt_prompt_img], outputs=[txt2img_prompt, txt_prompt_img]) + txt2img_prompt.submit(**txt2img_dict) + txt2img_submit.click(**txt2img_dict) + global txt2img_paste_fields # pylint: disable=global-statement txt2img_paste_fields = [ # prompt (txt2img_prompt, "Prompt"), @@ -570,11 +639,12 @@ def create_ui(startup_timer = None): (seed_resize_from_h, "Seed resize from-2"), *modules.scripts.scripts_txt2img.infotext_fields ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) - parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None)) + generation_parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) + txt2img_bindings = generation_parameters_copypaste.ParamBinding(paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None) + 
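txt2img_dict above is assembled once and splatted into both triggers, so pressing Enter in the prompt box and clicking Generate always stay in sync. The pattern, stripped to its essentials (toy generate function, assuming Gradio 3.x):

    import gradio as gr

    def generate(prompt):
        return f'generated: {prompt}'

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label='Prompt')
        submit = gr.Button('Generate')
        output = gr.Textbox(label='Output')
        generate_dict = dict(fn=generate, inputs=[prompt], outputs=[output], show_progress=False)
        prompt.submit(**generate_dict)  # Enter in the textbox
        submit.click(**generate_dict)   # Generate button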
generation_parameters_copypaste.register_paste_params_button(txt2img_bindings) - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + txt2img_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[txt2img_token_counter]) + txt2img_negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[txt2img_negative_token_counter]) ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) @@ -584,12 +654,12 @@ def create_ui(startup_timer = None): modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_paste, img2img_extra_networks_button, img2img_token_counter, img2img_token_button, img2img_negative_token_counter, img2img_negative_token_button = create_toprow(is_img2img=True, id_part="img2img") img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks_ui: from modules import ui_extra_networks - extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks_ui, extra_networks_button, 'img2img', skip_indexing=opts.extra_network_skip_indexing) + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks_ui, img2img_extra_networks_button, 'img2img', skip_indexing=opts.extra_network_skip_indexing) with FormRow(elem_id="img2img_interface", equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): @@ -643,16 +713,16 @@ def update_orig(image, state): init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if modules.shared.cmd_opts.hide_ui_dir_config else '' + hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' gr.HTML( "Upload images or process images in a directory" + "<br>Add inpaint batch mask directory to enable inpaint batch processing" f"{hidden}
" ) img2img_batch_files = gr.Files(label="Batch Process", interactive=True, elem_id="img2img_image_batch") - img2img_batch_input_dir = gr.Textbox(label="Inpaint batch input directory", **modules.shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Inpaint batch output directory", **modules.shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory", **modules.shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + img2img_batch_input_dir = gr.Textbox(label="Inpaint batch input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Inpaint batch output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] for i, tab in enumerate(img2img_tabs): @@ -663,78 +733,18 @@ def update_orig(image, state): button.click(fn=lambda: None, _js=f"switch_to_{name.replace(' ', '_')}", inputs=[], outputs=[]) with FormGroup(elem_classes="settings-accordion"): - with gr.Accordion(open=False, label="Sampler", elem_classes=["small-accordion"], elem_id="img2img_sampling_group"): - modules.sd_samplers.set_samplers() - steps, sampler_index = create_sampler_and_steps_selection(modules.sd_samplers.samplers_for_img2img, "img2img") - - with gr.Accordion(open=False, label="Resize", elem_classes=["small-accordion"], elem_id="img2img_resize_group"): - with gr.Row(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["None", "Resize fixed", "Crop and resize", "Resize and fill", "Latent upscale"], type="index", value="None") - - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated - - with gr.Tabs(): - with gr.Tab(label="Resize to") as tab_scale_to: - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - with FormRow(): - width = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_column_dim", scale=1, elem_classes="dimensions-tools"): - with FormRow(): - res_switch_btn = ToolButton(value=symbols.switch, elem_id="img2img_res_switch_btn") - detect_image_size_btn = ToolButton(value=symbols.detect, elem_id="img2img_detect_image_size_btn") - - with gr.Tab(label="Resize by") as tab_scale_by: - scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") - - with FormRow(): - scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") - gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") - button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") - - on_change_args = dict( - fn=resize_from_to_html, - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, scale_by], - outputs=scale_by_html, - show_progress=False, - ) - - scale_by.release(**on_change_args) - button_update_resize_to.click(**on_change_args) - - for component in [init_img, sketch]: - component.change(fn=lambda: None, 
_js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - - tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) - tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) - - with gr.Accordion(open=False, label="Batch", elem_classes=["small-accordion"], elem_id="img2img_batch_group"): - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + steps, sampler_index = create_sampler_inputs('img2img') + resize_mode, resize_name, width, height, scale_by, selected_scale_tab, _resize_time = create_resize_inputs('img2img', [init_img, sketch]) + batch_count, batch_size = create_batch_inputs('img2img') seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs('img2img') with gr.Accordion(open=False, label="Denoise", elem_classes=["small-accordion"], elem_id="img2img_denoise_group"): with FormRow(): - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + denoising_strength = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="img2img_denoising_strength") refiner_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Denoise start', value=0.0, elem_id="img2img_refiner_start") - with gr.Accordion(open=False, label="Advanced", elem_classes=["small-accordion"], elem_id="img2img_advanced_group"): - with FormRow(): - cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label='CFG scale', value=6.0, elem_id="img2img_cfg_scale") - image_cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.15, label='Image CFG scale', value=1.5, elem_id="img2img_image_cfg_scale") - with FormRow(): - clip_skip = gr.Slider(label='CLIP skip', value=1, minimum=1, maximum=4, step=1, elem_id='img2img_clip_skip', interactive=True) - diffusers_guidance_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Guidance rescale', value=0.7, elem_id="txt2img_image_cfg_rescale") - with FormRow(elem_classes="img2img_checkboxes_row", variant="compact"): - full_quality = gr.Checkbox(label='Full quality', value=True, elem_id="img2img_full_quality") - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(modules.shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, full_quality, restore_faces, tiling, hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry = create_advanced_inputs('img2img') with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: with FormRow(): @@ -757,73 +767,67 @@ def select_img2img_tab(tab): for i, elem in enumerate(img2img_tabs): elem.select(fn=lambda tab=i: select_img2img_tab(tab), inputs=[], outputs=[inpaint_controls, mask_alpha]) # pylint: disable=cell-var-from-loop - with FormRow(elem_id="img2img_override_settings_row") as row: - override_settings = create_override_settings_dropdown('img2img', row) + override_settings = create_override_inputs('img2img') with FormGroup(elem_id="img2img_script_container"): - custom_inputs = 
modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, _html_info_formatted, html_log = ui_common.create_output_panel("img2img") - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - img2img_args = dict( + img2img_script_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, img2img_generation_info, img2img_html_info, _img2img_html_info_formatted, img2img_html_log = ui_common.create_output_panel("img2img") + + connect_reuse_seed(seed, reuse_seed, img2img_generation_info, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, img2img_generation_info, is_subseed=True) + + img2img_prompt_img.change(fn=modules.images.image_data, inputs=[img2img_prompt_img], outputs=[img2img_prompt, img2img_prompt_img]) + dummy_component1 = gr.Textbox(visible=False, value='dummy') + dummy_component2 = gr.Number(visible=False, value=0) + global img2img_args # pylint: disable=global-statement + img2img_args = [ + dummy_component1, dummy_component2, + img2img_prompt, img2img_negative_prompt, img2img_prompt_styles, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + inpaint_color_sketch_orig, + init_img_inpaint, + init_mask_inpaint, + steps, + sampler_index, latent_index, + mask_blur, mask_alpha, + inpainting_fill, + full_quality, restore_faces, tiling, + batch_count, batch_size, + cfg_scale, image_cfg_scale, + diffusers_guidance_rescale, + refiner_steps, + refiner_start, + clip_skip, + denoising_strength, + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, + selected_scale_tab, + height, width, + scale_by, + resize_mode, resize_name, + inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, + img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, + hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, + override_settings, + ] + img2img_dict = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", - inputs=[ - dummy_component, dummy_component, - img2img_prompt, img2img_negative_prompt, - img2img_prompt_styles, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - inpaint_color_sketch_orig, - init_img_inpaint, - init_mask_inpaint, - steps, - sampler_index, latent_index, - mask_blur, mask_alpha, - inpainting_fill, - full_quality, restore_faces, tiling, - batch_count, batch_size, - cfg_scale, image_cfg_scale, - diffusers_guidance_rescale, - refiner_steps, - refiner_start, - clip_skip, - denoising_strength, - seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, - selected_scale_tab, - height, width, - scale_by, - resize_mode, - inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, - img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, - override_settings, - ] + custom_inputs, + inputs= img2img_args + img2img_script_inputs, outputs=[ img2img_gallery, - generation_info, - html_info, - html_log, + img2img_generation_info, + img2img_html_info, + img2img_html_log, ], show_progress=False, ) - img2img_prompt.submit(**img2img_args) - 
submit.click(**img2img_args) + img2img_prompt.submit(**img2img_dict) + submit.click(**img2img_dict) + dummy_component = gr.Textbox(visible=False, value='dummy') interrogate_args = dict( _js="get_img2img_tab_index", @@ -843,20 +847,11 @@ def select_img2img_tab(tab): img2img_interrogate.click(fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args) img2img_deepbooru.click(fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), **interrogate_args) - res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) - - detect_image_size_btn.click( - fn=lambda w, h, _: (w or gr.update(), h or gr.update()), - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, dummy_component], - outputs=[width, height], - show_progress=False, - ) - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter]) + img2img_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_prompt, steps], outputs=[img2img_token_counter]) + img2img_negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[img2img_negative_token_counter]) ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + global img2img_paste_fields # pylint: disable=global-statement img2img_paste_fields = [ # prompt (img2img_prompt, "Prompt"), @@ -899,29 +894,40 @@ def select_img2img_tab(tab): (seed_resize_from_h, "Seed resize from-2"), *modules.scripts.scripts_img2img.infotext_fields ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) - parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, - )) + generation_parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) + generation_parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) + img2img_bindings = generation_parameters_copypaste.ParamBinding(paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None) + generation_parameters_copypaste.register_paste_params_button(img2img_bindings) timer.startup.record("ui-img2img") modules.scripts.scripts_current = None + if shared.backend == shared.Backend.DIFFUSERS: + with gr.Blocks(analytics_enabled=False) as control_interface: + from modules import ui_control + ui_control.create_ui() + timer.startup.record("ui-control") + else: + control_interface = None + with gr.Blocks(analytics_enabled=False) as extras_interface: + from modules import ui_postprocessing ui_postprocessing.create_ui() timer.startup.record("ui-extras") with gr.Blocks(analytics_enabled=False) as train_interface: + from modules import ui_train ui_train.create_ui([txt2img_prompt, txt2img_negative_prompt, steps, sampler_index, cfg_scale, seed, width, height]) timer.startup.record("ui-train") with gr.Blocks(analytics_enabled=False) as models_interface: + from modules import ui_models ui_models.create_ui() timer.startup.record("ui-models") 
with gr.Blocks(analytics_enabled=False) as interrogate_interface: + from modules import ui_interrogate ui_interrogate.create_ui() timer.startup.record("ui-interrogate") @@ -988,7 +994,8 @@ def get_opt_values(): loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config) components = [] component_dict = {} - modules.shared.settings_components = component_dict + shared.settings_components = component_dict + dummy_component1 = gr.Label(visible=False) script_callbacks.ui_settings_callback() opts.reorder() @@ -1006,17 +1013,17 @@ def run_settings(*args): if cmd_opts.use_directml: directml_override_opts() if cmd_opts.use_openvino: - if not modules.shared.opts.cuda_compile: - modules.shared.log.warning("OpenVINO: Enabling Torch Compile") - modules.shared.opts.cuda_compile = True - if modules.shared.opts.cuda_compile_backend != "openvino_fx": - modules.shared.log.warning("OpenVINO: Setting Torch Compiler backend to OpenVINO FX") - modules.shared.opts.cuda_compile_backend = "openvino_fx" - if modules.shared.opts.sd_backend != "diffusers": - modules.shared.log.warning("OpenVINO: Setting backend to Diffusers") - modules.shared.opts.sd_backend = "diffusers" + if not shared.opts.cuda_compile: + shared.log.warning("OpenVINO: Enabling Torch Compile") + shared.opts.cuda_compile = True + if shared.opts.cuda_compile_backend != "openvino_fx": + shared.log.warning("OpenVINO: Setting Torch Compiler backend to OpenVINO FX") + shared.opts.cuda_compile_backend = "openvino_fx" + if shared.opts.sd_backend != "diffusers": + shared.log.warning("OpenVINO: Setting backend to Diffusers") + shared.opts.sd_backend = "diffusers" try: - opts.save(modules.shared.config_filename) + opts.save(shared.config_filename) if len(changed) > 0: log.info(f'Settings: changed={len(changed)} {changed}') except RuntimeError: @@ -1031,7 +1038,7 @@ def run_settings_single(value, key): return gr.update(value=getattr(opts, key)), opts.dumpjson() if cmd_opts.use_directml: directml_override_opts() - opts.save(modules.shared.config_filename) + opts.save(shared.config_filename) log.debug(f'Setting changed: key={key}, value={value}') return get_value_for_setting(key), opts.dumpjson() @@ -1077,7 +1084,7 @@ def run_settings_single(value, key): current_row = gr.Column(variant='compact') current_row.__enter__() previous_section = item.section - if k in quicksettings_names and not modules.shared.cmd_opts.freeze: + if k in quicksettings_names and not shared.cmd_opts.freeze: quicksettings_list.append((i, k, item)) components.append(dummy_component) elif section_must_be_skipped: @@ -1107,7 +1114,7 @@ def run_settings_single(value, key): gr.Markdown(md) with gr.TabItem("Licenses", id="system_licenses", elem_id="system_tab_licenses"): - gr.HTML(modules.shared.html("licenses.html"), elem_id="licenses", elem_classes="licenses") + gr.HTML(shared.html("licenses.html"), elem_id="licenses", elem_classes="licenses") create_dirty_indicator("tab_licenses", [], interactive=False) def unload_sd_weights(): @@ -1124,39 +1131,42 @@ def reload_sd_weights(): timer.startup.record("ui-settings") - interfaces = [ - (txt2img_interface, "Text", "txt2img"), - (img2img_interface, "Image", "img2img"), - (extras_interface, "Process", "process"), - (train_interface, "Train", "train"), - (models_interface, "Models", "models"), - (interrogate_interface, "Interrogate", "interrogate"), - ] + interfaces = [] + interfaces += [(txt2img_interface, "Text", "txt2img")] + interfaces += [(img2img_interface, "Image", "img2img")] + interfaces += [(control_interface, "Control", "control")] if 
control_interface is not None else [] + interfaces += [(extras_interface, "Process", "process")] + interfaces += [(interrogate_interface, "Interrogate", "interrogate")] + interfaces += [(train_interface, "Train", "train")] + interfaces += [(models_interface, "Models", "models")] interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "System", "system")] + + from modules import ui_extensions extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] timer.startup.record("ui-extensions") - modules.shared.tab_names = [] + shared.tab_names = [] for _interface, label, _ifid in interfaces: - modules.shared.tab_names.append(label) + shared.tab_names.append(label) - with gr.Blocks(theme=modules.theme.gradio_theme, analytics_enabled=False, title="SD.Next") as demo: + with gr.Blocks(theme=theme.gradio_theme, analytics_enabled=False, title="SD.Next") as demo: with gr.Row(elem_id="quicksettings", variant="compact"): for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component - parameters_copypaste.connect_paste_params_buttons() + generation_parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: for interface, label, ifid in interfaces: if interface is None: continue - # if label in modules.shared.opts.hidden_tabs or label == '': + # if label in shared.opts.hidden_tabs or label == '': # continue with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"): + # log.debug(f'UI render: id={ifid}') interface.render() for interface, _label, ifid in interfaces: if interface is None: @@ -1176,9 +1186,9 @@ def reload_sd_weights(): inputs=components, outputs=[text_settings, result], ) - defaults_submit.click(fn=lambda: modules.shared.restore_defaults(restart=True), _js="restartReload") - restart_submit.click(fn=lambda: modules.shared.restart_server(restart=True), _js="restartReload") - shutdown_submit.click(fn=lambda: modules.shared.restart_server(restart=False), _js="restartReload") + defaults_submit.click(fn=lambda: shared.restore_defaults(restart=True), _js="restartReload") + restart_submit.click(fn=lambda: shared.restart_server(restart=True), _js="restartReload") + shutdown_submit.click(fn=lambda: shared.restart_server(restart=False), _js="restartReload") for _i, k, _item in quicksettings_list: component = component_dict[k] @@ -1214,11 +1224,13 @@ def reload_sd_weights(): ) def reference_submit(model): - from modules import modelloader - loaded = modelloader.load_reference(model) - if loaded: + if '@' not in model: # diffusers + loaded = modelloader.load_reference(model) return model if loaded else opts.sd_model_checkpoint - return loaded + else: # civitai + model, url = model.split('@') + loaded = modelloader.load_civitai(model, url) + return loaded if loaded is not None else opts.sd_model_checkpoint button_set_reference = gr.Button('Change reference', elem_id='change_reference', visible=False) button_set_reference.click( @@ -1295,7 +1307,7 @@ def stylesheet(fn): if not os.path.isfile(cssfile): continue head += stylesheet(cssfile) - if opts.gradio_theme in modules.theme.list_builtin_themes(): + if opts.gradio_theme in theme.list_builtin_themes(): head += stylesheet(os.path.join(script_path, "javascript", f"{opts.gradio_theme}.css")) if os.path.exists(os.path.join(data_path, "user.css")): head += stylesheet(os.path.join(data_path, "user.css")) @@ -1305,13 
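reference_submit above dispatches on the presence of '@' in the selected reference entry: plain names go through modelloader.load_reference, while 'name@url' entries are split and handed to the new load_civitai download path. Only the parsing step, as a standalone sketch (parse_reference and the example URL are made up for illustration):

    def parse_reference(model: str):
        # entries without '@' are diffusers reference models; 'name@url' entries come from civitai
        if '@' not in model:
            return 'diffusers', model, None
        name, url = model.split('@', maxsplit=1)
        return 'civitai', name, url

    assert parse_reference('runwayml/stable-diffusion-v1-5') == ('diffusers', 'runwayml/stable-diffusion-v1-5', None)
    assert parse_reference('some-model@https://example.com/download/123') == ('civitai', 'some-model', 'https://example.com/download/123')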
+1317,13 @@ def stylesheet(fn): def reload_javascript(): - is_builtin = modules.theme.reload_gradio_theme() + is_builtin = theme.reload_gradio_theme() head = html_head() css = html_css(is_builtin) body = html_body() def template_response(*args, **kwargs): - res = modules.shared.GradioTemplateResponseOriginal(*args, **kwargs) + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace(b'', f'{head}'.encode("utf8")) res.body = res.body.replace(b'', f'{css}{body}'.encode("utf8")) res.init_headers() @@ -1335,5 +1347,5 @@ def quicksettings_hint(): app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) -if not hasattr(modules.shared, 'GradioTemplateResponseOriginal'): - modules.shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse diff --git a/modules/ui_common.py b/modules/ui_common.py index d63c1ce7a..5c35e3f4c 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -13,6 +13,8 @@ folder_symbol = symbols.folder +debug = shared.log.trace if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: PASTE') def update_generation_info(generation_info, html_info, img_index): @@ -157,15 +159,16 @@ def __init__(self, d=None): fullfns.append(fullfn) if txt_fullfn: filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) + # fullfns.append(txt_fullfn) modules.script_callbacks.image_save_btn_callback(filename) if shared.opts.samples_save_zip and len(fullfns) > 1: zip_filepath = os.path.join(shared.opts.outdir_save, "images.zip") from zipfile import ZipFile with ZipFile(zip_filepath, "w") as zip_file: for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) + if os.path.isfile(fullfns[i]): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) fullfns.insert(0, zip_filepath) return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0] if len(filenames) > 0 else 'none'}") @@ -194,13 +197,15 @@ def open_folder(result_gallery, gallery_index = 0): subprocess.Popen(["xdg-open", path]) # pylint: disable=consider-using-with -def create_output_panel(tabname): +def create_output_panel(tabname, preview=True): import modules.generation_parameters_copypaste as parameters_copypaste with gr.Column(variant='panel', elem_id=f"{tabname}_results"): with gr.Group(elem_id=f"{tabname}_gallery_container"): + if tabname == "txt2img": + gr.HTML(value="", elem_id="main_info", visible=False, elem_classes=["main-info"]) # columns are for <576px, <768px, <992px, <1200px, <1400px, >1400px - result_gallery = gr.Gallery(value=[], label='Output', show_label=False, show_download_button=True, allow_preview=True, elem_id=f"{tabname}_gallery", container=False, preview=True, columns=5, object_fit='scale-down', height=shared.opts.gallery_height or None) + result_gallery = gr.Gallery(value=[], label='Output', show_label=False, show_download_button=True, allow_preview=True, elem_id=f"{tabname}_gallery", container=False, preview=preview, columns=5, object_fit='scale-down', height=shared.opts.gallery_height or None) with gr.Column(elem_id=f"{tabname}_footer", elem_classes="gallery_footer"): dummy_component = gr.Label(visible=False) @@ -213,7 +218,10 @@ def create_output_panel(tabname): clip_files.click(fn=None, _js='clip_gallery_urls', 
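The new SD_PASTE_DEBUG / SD_CONTROL_DEBUG switches rely on a small trick: when the environment variable is unset, debug is bound to a no-op lambda, so trace calls can stay in the code path unconditionally and cost almost nothing. A sketch using the standard logging module in place of shared.log (which exposes a custom trace level):

    import os
    import logging

    log = logging.getLogger('sdnext-sketch')  # stand-in for shared.log
    debug = log.debug if os.environ.get('SD_PASTE_DEBUG', None) is not None else lambda *args, **kwargs: None
    debug('Trace: PASTE')  # emitted only when SD_PASTE_DEBUG is set and logging is configured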
inputs=[result_gallery], outputs=[]) save = gr.Button('Save', elem_id=f'save_{tabname}') delete = gr.Button('Delete', elem_id=f'delete_{tabname}') - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + if shared.backend == shared.Backend.ORIGINAL: + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + else: + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "control", "extras"]) download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): @@ -244,9 +252,9 @@ def create_output_panel(tabname): else: paste_field_names = [] for paste_tabname, paste_button in buttons.items(): - parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names - )) + debug(f'Create output panel: button={paste_button} tabname={paste_tabname}') + bindings = parameters_copypaste.ParamBinding(paste_button=paste_button, tabname=paste_tabname, source_tabname=("txt2img" if tabname == "txt2img" else None), source_image_component=result_gallery, paste_field_names=paste_field_names) + parameters_copypaste.register_paste_params_button(bindings) return result_gallery, generation_info, html_info, html_info_formatted, html_log diff --git a/modules/ui_control.py b/modules/ui_control.py new file mode 100644 index 000000000..8365d0127 --- /dev/null +++ b/modules/ui_control.py @@ -0,0 +1,618 @@ +import os +import gradio as gr +from modules.control import unit +from modules.control import processors # patrickvonplaten controlnet_aux +from modules.control.units import controlnet # lllyasviel ControlNet +from modules.control.units import xs # vislearn ControlNet-XS +from modules.control.units import lite # vislearn ControlNet-XS +from modules.control.units import t2iadapter # TencentARC T2I-Adapter +from modules.control.units import reference # reference pipeline +from modules.control.units import ipadapter # reference pipeline +from modules import errors, shared, progress, sd_samplers, ui, ui_components, ui_symbols, ui_common, generation_parameters_copypaste, call_queue +from modules.ui_components import FormRow, FormGroup + + +gr_height = 512 +max_units = 5 +units: list[unit.Unit] = [] # main state variable +input_source = None +input_init = None +debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: CONTROL') + + +def initialize(): + from modules import devices + shared.log.debug(f'Control initialize: models={shared.opts.control_dir}') + controlnet.cache_dir = os.path.join(shared.opts.control_dir, 'controlnet') + xs.cache_dir = os.path.join(shared.opts.control_dir, 'xs') + lite.cache_dir = os.path.join(shared.opts.control_dir, 'lite') + t2iadapter.cache_dir = os.path.join(shared.opts.control_dir, 'adapter') + processors.cache_dir = os.path.join(shared.opts.control_dir, 'processor') + unit.default_device = devices.device + unit.default_dtype = devices.dtype + os.makedirs(shared.opts.control_dir, exist_ok=True) + os.makedirs(controlnet.cache_dir, exist_ok=True) + os.makedirs(xs.cache_dir, exist_ok=True) + os.makedirs(lite.cache_dir, exist_ok=True) + os.makedirs(t2iadapter.cache_dir, exist_ok=True) + os.makedirs(processors.cache_dir, exist_ok=True) + + +def return_controls(res): + 
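+    # contract: always return exactly five values in the order
+    # [preview_image, output_image, output_video, output_gallery, result_text],
+    # matching the five output slots that generate_click yields back to the UI;
+    # control_run may produce a bare error string or a result tuple, both get normalized below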
# return preview, image, video, gallery, text + debug(f'Control received: type={type(res)} {res}') + if isinstance(res, str): # error response + return [None, None, None, None, res] + elif isinstance(res, tuple): # standard response received as tuple via control_run->yield(output_images, process_image, result_txt) + preview_image = res[1] # may be None + output_image = res[0][0] if isinstance(res[0], list) else res[0] # may be image or list of images + if isinstance(res[0], list): + output_gallery = res[0] if res[0][0] is not None else [] + else: + output_gallery = [res[0]] if res[0] is not None else [] # must return list, but can receive single image + result_txt = res[2] if len(res) > 2 else '' # do we have a message + output_video = res[3] if len(res) > 3 else None # do we have a video filename + return [preview_image, output_image, output_video, output_gallery, result_txt] + else: # unexpected + return [None, None, None, None, f'Control: Unexpected response: {type(res)}'] + + +def generate_click(job_id: str, active_tab: str, *args): + from modules.control.run import control_run + shared.log.debug(f'Control: tab={active_tab} job={job_id} args={args}') + if active_tab not in ['controlnet', 'xs', 'adapter', 'reference', 'lite']: + return None, None, None, None, f'Control: Unknown mode: {active_tab} args={args}' + shared.state.begin('control') + progress.add_task_to_queue(job_id) + with call_queue.queue_lock: + yield [None, None, None, None, 'Control: starting'] + shared.mem_mon.reset() + progress.start_task(job_id) + try: + for results in control_run(units, input_source, input_init, active_tab, True, *args): + progress.record_results(job_id, results) + yield return_controls(results) + except Exception as e: + shared.log.error(f"Control exception: {e}") + errors.display(e, 'Control') + return None, None, None, None, f'Control: Exception: {e}' + progress.finish_task(job_id) + shared.state.end() + + +def display_units(num_units): + return (num_units * [gr.update(visible=True)]) + ((max_units - num_units) * [gr.update(visible=False)]) + + +def get_video(filepath: str): + try: + import cv2 + from modules.control.util import decode_fourcc + video = cv2.VideoCapture(filepath) + if not video.isOpened(): + msg = f'Control: video open failed: path="{filepath}"' + shared.log.error(msg) + return msg + frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = video.get(cv2.CAP_PROP_FPS) + duration = float(frames) / fps + w, h = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + codec = decode_fourcc(video.get(cv2.CAP_PROP_FOURCC)) + video.release() + shared.log.debug(f'Control: input video: path={filepath} frames={frames} fps={fps} size={w}x{h} codec={codec}') + msg = f'Control input | Video | Size {w}x{h} | Frames {frames} | FPS {fps:.2f} | Duration {duration:.2f} | Codec {codec}' + return msg + except Exception as e: + msg = f'Control: video open failed: path={filepath} {e}' + shared.log.error(msg) + return msg + + +def select_input(selected_input, selected_init, init_type): + debug(f'Control select input: source={selected_input} init={selected_init}, type={init_type}') + global input_source, input_init # pylint: disable=global-statement + input_type = type(selected_input) + status = 'Control input | Unknown' + res = [gr.Tabs.update(selected='out-gallery'), status] + # control inputs + if hasattr(selected_input, 'size'): # image via upload -> image + input_source = [selected_input] + input_type = 'PIL.Image' + shared.log.debug(f'Control input: type={input_type} 
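display_units above keeps a fixed pool of max_units unit rows and only toggles their visibility, so no Gradio components need to be created or destroyed when the count changes. A reduced sketch of the idea (plain textboxes stand in for the real control-unit rows), assuming Gradio 3.x:

    import gradio as gr

    max_units = 5

    def display_units(num_units):
        num_units = int(num_units)  # slider values may arrive as floats
        # the first num_units components become visible, the rest are hidden
        return (num_units * [gr.update(visible=True)]) + ((max_units - num_units) * [gr.update(visible=False)])

    with gr.Blocks() as demo:
        num_units = gr.Slider(label='Units', minimum=1, maximum=max_units, step=1, value=1)
        unit_boxes = [gr.Textbox(label=f'Unit {i + 1}', visible=(i == 0)) for i in range(max_units)]
        num_units.change(fn=display_units, inputs=[num_units], outputs=unit_boxes)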
input={input_source}') + status = f'Control input | Image | Size {selected_input.width}x{selected_input.height} | Mode {selected_input.mode}' + res = [gr.Tabs.update(selected='out-gallery'), status] + elif isinstance(selected_input, gr.components.image.Image): # not likely + input_source = [selected_input.value] + input_type = 'gr.Image' + shared.log.debug(f'Control input: type={input_type} input={input_source}') + res = [gr.Tabs.update(selected='out-gallery'), status] + elif isinstance(selected_input, str): # video via upload > tmp filepath to video + input_source = selected_input + input_type = 'gr.Video' + shared.log.debug(f'Control input: type={input_type} input={input_source}') + status = get_video(input_source) + res = [gr.Tabs.update(selected='out-video'), status] + elif isinstance(selected_input, list): # batch or folder via upload -> list of tmp filepaths + if hasattr(selected_input[0], 'name'): + input_type = 'tempfiles' + input_source = [f.name for f in selected_input] # tempfile + else: + input_type = 'files' + input_source = selected_input + status = f'Control input | Images | Files {len(input_source)}' + shared.log.debug(f'Control input: type={input_type} input={input_source}') + res = [gr.Tabs.update(selected='out-gallery'), status] + else: # unknown + input_source = None + # init inputs: optional + if init_type == 0: # Control only + input_init = None + elif init_type == 1: # Init image same as control assigned during runtime + input_init = None + elif init_type == 2: # Separate init image + if hasattr(selected_init, 'size'): # image via upload -> image + input_init = [selected_init] + input_type = 'PIL.Image' + shared.log.debug(f'Control input: type={input_type} input={input_init}') + status = f'Control input | Image | Size {selected_init.width}x{selected_init.height} | Mode {selected_init.mode}' + res = [gr.Tabs.update(selected='out-gallery'), status] + elif isinstance(selected_init, gr.components.image.Image): # not likely + input_init = [selected_init.value] + input_type = 'gr.Image' + shared.log.debug(f'Control input: type={input_type} input={input_init}') + res = [gr.Tabs.update(selected='out-gallery'), status] + elif isinstance(selected_init, str): # video via upload > tmp filepath to video + input_init = selected_init + input_type = 'gr.Video' + shared.log.debug(f'Control input: type={input_type} input={input_init}') + status = get_video(input_init) + res = [gr.Tabs.update(selected='out-video'), status] + elif isinstance(selected_init, list): # batch or folder via upload -> list of tmp filepaths + if hasattr(selected_init[0], 'name'): + input_type = 'tempfiles' + input_init = [f.name for f in selected_init] # tempfile + else: + input_type = 'files' + input_init = selected_init + status = f'Control input | Images | Files {len(input_init)}' + shared.log.debug(f'Control input: type={input_type} input={input_init}') + res = [gr.Tabs.update(selected='out-gallery'), status] + else: # unknown + input_init = None + debug(f'Control select input: source={input_source} init={input_init}') + return res + + +def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + ] + + +def create_ui(_blocks: gr.Blocks=None): + initialize() + if shared.backend == shared.Backend.ORIGINAL: + with gr.Blocks(analytics_enabled = False) as control_ui: + pass + return [(control_ui, 'Control', 'control')] + + with 
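video_type_change above returns one gr.update per output component, matched by position, so a single dropdown drives the visibility of every video-related control. The same wiring as a self-contained snippet, assuming Gradio 3.x:

    import gradio as gr

    def video_type_change(video_type):
        return [
            gr.update(visible=video_type != 'None'),          # duration
            gr.update(visible=video_type in ('GIF', 'PNG')),  # loop
            gr.update(visible=video_type == 'MP4'),           # pad frames
            gr.update(visible=video_type == 'MP4'),           # interpolate frames
        ]

    with gr.Blocks() as demo:
        video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
        video_duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
        video_loop = gr.Checkbox(label='Loop', value=True, visible=False)
        video_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
        video_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
        video_type.change(fn=video_type_change, inputs=[video_type], outputs=[video_duration, video_loop, video_pad, video_interpolate])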
gr.Blocks(analytics_enabled = False) as control_ui: + prompt, styles, negative, btn_generate, _btn_interrogate, _btn_deepbooru, btn_paste, btn_extra, prompt_counter, btn_prompt_counter, negative_counter, btn_negative_counter = ui.create_toprow(is_img2img=False, id_part='control') + with FormGroup(elem_id="control_interface", equal_height=False): + with gr.Row(elem_id='control_settings'): + + with gr.Accordion(open=False, label="Input", elem_id="control_input", elem_classes=["small-accordion"]): + with gr.Row(): + show_ip = gr.Checkbox(label="Enable IP adapter", value=False, elem_id="control_show_ip") + with gr.Row(): + show_preview = gr.Checkbox(label="Show preview", value=False, elem_id="control_show_preview") + with gr.Row(): + input_type = gr.Radio(label="Input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type') + with gr.Row(): + denoising_strength = gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="control_denoising_strength") + + resize_mode, resize_name, width, height, scale_by, selected_scale_tab, resize_time = ui.create_resize_inputs('control', [], time_selector=True, scale_visible=False, mode='Fixed') + + with gr.Accordion(open=False, label="Sampler", elem_id="control_sampler", elem_classes=["small-accordion"]): + sd_samplers.set_samplers() + steps, sampler_index = ui.create_sampler_and_steps_selection(sd_samplers.samplers, "control") + + batch_count, batch_size = ui.create_batch_inputs('control') + seed, _reuse_seed, subseed, _reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = ui.create_seed_inputs('control', reuse_visible=False) + cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, full_quality, restore_faces, tiling, hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry = ui.create_advanced_inputs('control') + + with gr.Accordion(open=False, label="Video", elem_id="control_video", elem_classes=["small-accordion"]): + with gr.Row(): + video_skip_frames = gr.Slider(minimum=0, maximum=100, step=1, label='Skip input frames', value=0, elem_id="control_video_skip_frames") + with gr.Row(): + video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') + video_duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) + with gr.Row(): + video_loop = gr.Checkbox(label='Loop', value=True, visible=False) + video_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) + video_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[video_duration, video_loop, video_pad, video_interpolate]) + + override_settings = ui.create_override_inputs('control') + + with FormRow(variant='compact', elem_id="control_extra_networks", visible=False) as extra_networks_ui: + from modules import timer, ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks_ui, btn_extra, 'control', skip_indexing=shared.opts.extra_network_skip_indexing) + timer.startup.record('ui-extra-networks') + + with gr.Row(elem_id='control_status'): + result_txt = gr.HTML(elem_classes=['control-result'], elem_id='control-result') + + with gr.Row(elem_id='control-inputs'): + with gr.Column(scale=9, 
elem_id='control-input-column', visible=True) as _column_input: + gr.HTML('Control input
') + with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-input'): + with gr.Tab('Image', id='in-image') as tab_image: + input_image = gr.Image(label="Input", show_label=False, type="pil", source="upload", interactive=True, tool="editor", height=gr_height) + with gr.Tab('Video', id='in-video') as tab_video: + input_video = gr.Video(label="Input", show_label=False, interactive=True, height=gr_height) + with gr.Tab('Batch', id='in-batch') as tab_batch: + input_batch = gr.File(label="Input", show_label=False, file_count='multiple', file_types=['image'], type='file', interactive=True, height=gr_height) + with gr.Tab('Folder', id='in-folder') as tab_folder: + input_folder = gr.File(label="Input", show_label=False, file_count='directory', file_types=['image'], type='file', interactive=True, height=gr_height) + with gr.Column(scale=9, elem_id='control-init-column', visible=False) as column_init: + gr.HTML('Init input
') + with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-init'): + with gr.Tab('Image', id='init-image') as tab_image_init: + init_image = gr.Image(label="Input", show_label=False, type="pil", source="upload", interactive=True, tool="editor", height=gr_height) + with gr.Tab('Video', id='init-video') as tab_video_init: + init_video = gr.Video(label="Input", show_label=False, interactive=True, height=gr_height) + with gr.Tab('Batch', id='init-batch') as tab_batch_init: + init_batch = gr.File(label="Input", show_label=False, file_count='multiple', file_types=['image'], type='file', interactive=True, height=gr_height) + with gr.Tab('Folder', id='init-folder') as tab_folder_init: + init_folder = gr.File(label="Input", show_label=False, file_count='directory', file_types=['image'], type='file', interactive=True, height=gr_height) + with gr.Column(scale=9, elem_id='control-init-column', visible=False) as column_ip: + gr.HTML('IP Adapter
') + with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-ip'): + with gr.Tab('Image', id='init-image') as tab_image_init: + ip_image = gr.Image(label="Input", show_label=False, type="pil", source="upload", interactive=True, tool="editor", height=gr_height) + with gr.Row(): + ip_adapter = gr.Dropdown(label='Adapter', choices=ipadapter.ADAPTERS, value='none') + ip_scale = gr.Slider(label='Scale', minimum=0.0, maximum=1.0, step=0.01, value=0.5) + with gr.Row(): + ip_type = gr.Radio(label="Input type", choices=['Init image same as control', 'Separate init image'], value='Init image same as control', type='index', elem_id='control_ip_type') + ip_image.change(fn=lambda x: gr.update(value='Init image same as control' if x is None else 'Separate init image'), inputs=[ip_image], outputs=[ip_type]) + with gr.Column(scale=9, elem_id='control-output-column', visible=True) as _column_output: + gr.HTML('Output
') + with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-output') as output_tabs: + with gr.Tab('Gallery', id='out-gallery'): + output_gallery, _output_gen_info, _output_html_info, _output_html_info_formatted, _output_html_log = ui_common.create_output_panel("control", preview=True) + with gr.Tab('Image', id='out-image'): + output_image = gr.Image(label="Input", show_label=False, type="pil", interactive=False, tool="editor", height=gr_height) + with gr.Tab('Video', id='out-video'): + output_video = gr.Video(label="Input", show_label=False, height=gr_height) + with gr.Column(scale=9, elem_id='control-preview-column', visible=False) as column_preview: + gr.HTML('Preview
') + with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-preview'): + with gr.Tab('Preview', id='preview-image') as tab_image: + preview_process = gr.Image(label="Input", show_label=False, type="pil", source="upload", interactive=False, height=gr_height, visible=True) + + for ctrl in [input_image, input_video, input_batch, input_folder, init_image, init_video, init_batch, init_folder, tab_image, tab_video, tab_batch, tab_folder, tab_image_init, tab_video_init, tab_batch_init, tab_folder_init]: + inputs = [input_image, init_image, input_type] + outputs = [output_tabs, result_txt] + if hasattr(ctrl, 'change'): + ctrl.change(fn=select_input, inputs=inputs, outputs=outputs) + if hasattr(ctrl, 'select'): + ctrl.select(fn=select_input, inputs=inputs, outputs=outputs) + show_preview.change(fn=lambda x: gr.update(visible=x), inputs=[show_preview], outputs=[column_preview]) + show_ip.change(fn=lambda x: gr.update(visible=x), inputs=[show_ip], outputs=[column_ip]) + input_type.change(fn=lambda x: gr.update(visible=x == 2), inputs=[input_type], outputs=[column_init]) + + with gr.Tabs(elem_id='control-tabs') as _tabs_control_type: + + with gr.Tab('ControlNet') as _tab_controlnet: + gr.HTML('ControlNet') + with gr.Row(): + extra_controls = [ + gr.Checkbox(label="Guess mode", value=False, scale=3), + ] + num_controlnet_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1) + controlnet_ui_units = [] # list of hidable accordions + for i in range(max_units): + with gr.Accordion(f'Control unit {i+1}', visible= i < num_controlnet_units.value) as unit_ui: + with gr.Row(): + with gr.Column(): + with gr.Row(): + enabled_cb = gr.Checkbox(value= i==0, label="") + process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None') + model_id = gr.Dropdown(label="ControlNet", choices=controlnet.list_models(), value='None') + ui_common.create_refresh_button(model_id, controlnet.list_models, lambda: {"choices": controlnet.list_models(refresh=True)}, 'refresh_control_models') + model_strength = gr.Slider(label="Strength", minimum=0.01, maximum=1.0, step=0.01, value=1.0-i/10) + control_start = gr.Slider(label="Start", minimum=0.0, maximum=1.0, step=0.05, value=0) + control_end = gr.Slider(label="End", minimum=0.0, maximum=1.0, step=0.05, value=1.0) + reset_btn = ui_components.ToolButton(value=ui_symbols.reset) + image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool']) + process_btn= ui_components.ToolButton(value=ui_symbols.preview) + controlnet_ui_units.append(unit_ui) + units.append(unit.Unit( + unit_type = 'controlnet', + result_txt = result_txt, + image_input = input_image, + enabled_cb = enabled_cb, + reset_btn = reset_btn, + process_id = process_id, + model_id = model_id, + model_strength = model_strength, + preview_process = preview_process, + preview_btn = process_btn, + image_upload = image_upload, + control_start = control_start, + control_end = control_end, + extra_controls = extra_controls, + ) + ) + if i == 0: + units[-1].enabled = True # enable first unit in group + num_controlnet_units.change(fn=display_units, inputs=[num_controlnet_units], outputs=controlnet_ui_units) + + with gr.Tab('XS') as _tab_controlnetxs: + gr.HTML('ControlNet XS') + with gr.Row(): + extra_controls = [ + gr.Slider(label="Time embedding mix", minimum=0.0, maximum=1.0, step=0.05, value=0.0, scale=3) + ] + num_controlnet_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, 
value=1, scale=1) + controlnetxs_ui_units = [] # list of hidable accordions + for i in range(max_units): + with gr.Accordion(f'Control unit {i+1}', visible= i < num_controlnet_units.value) as unit_ui: + with gr.Row(): + with gr.Column(): + with gr.Row(): + enabled_cb = gr.Checkbox(value= i==0, label="") + process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None') + model_id = gr.Dropdown(label="ControlNet-XS", choices=xs.list_models(), value='None') + ui_common.create_refresh_button(model_id, xs.list_models, lambda: {"choices": xs.list_models(refresh=True)}, 'refresh_control_models') + model_strength = gr.Slider(label="Strength", minimum=0.01, maximum=1.0, step=0.01, value=1.0-i/10) + control_start = gr.Slider(label="Start", minimum=0.0, maximum=1.0, step=0.05, value=0) + control_end = gr.Slider(label="End", minimum=0.0, maximum=1.0, step=0.05, value=1.0) + reset_btn = ui_components.ToolButton(value=ui_symbols.reset) + image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool']) + process_btn= ui_components.ToolButton(value=ui_symbols.preview) + controlnetxs_ui_units.append(unit_ui) + units.append(unit.Unit( + unit_type = 'xs', + result_txt = result_txt, + image_input = input_image, + enabled_cb = enabled_cb, + reset_btn = reset_btn, + process_id = process_id, + model_id = model_id, + model_strength = model_strength, + preview_process = preview_process, + preview_btn = process_btn, + image_upload = image_upload, + control_start = control_start, + control_end = control_end, + extra_controls = extra_controls, + ) + ) + if i == 0: + units[-1].enabled = True # enable first unit in group + num_controlnet_units.change(fn=display_units, inputs=[num_controlnet_units], outputs=controlnetxs_ui_units) + + with gr.Tab('Adapter') as _tab_adapter: + gr.HTML('T2I-Adapter') + with gr.Row(): + extra_controls = [ + gr.Slider(label="Control factor", minimum=0.0, maximum=1.0, step=0.05, value=1.0, scale=3), + ] + num_adapter_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1) + adapter_ui_units = [] # list of hidable accordions + for i in range(max_units): + with gr.Accordion(f'Adapter unit {i+1}', visible= i < num_adapter_units.value) as unit_ui: + with gr.Row(): + with gr.Column(): + with gr.Row(): + enabled_cb = gr.Checkbox(value= i == 0, label="Enabled") + process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None') + model_id = gr.Dropdown(label="Adapter", choices=t2iadapter.list_models(), value='None') + ui_common.create_refresh_button(model_id, t2iadapter.list_models, lambda: {"choices": t2iadapter.list_models(refresh=True)}, 'refresh_adapter_models') + model_strength = gr.Slider(label="Strength", minimum=0.01, maximum=1.0, step=0.01, value=1.0-i/10) + reset_btn = ui_components.ToolButton(value=ui_symbols.reset) + image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool']) + process_btn= ui_components.ToolButton(value=ui_symbols.preview) + adapter_ui_units.append(unit_ui) + units.append(unit.Unit( + unit_type = 'adapter', + result_txt = result_txt, + image_input = input_image, + enabled_cb = enabled_cb, + reset_btn = reset_btn, + process_id = process_id, + model_id = model_id, + model_strength = model_strength, + preview_process = preview_process, + preview_btn = process_btn, + image_upload = image_upload, + extra_controls = extra_controls, + ) + ) + if i == 0: + 
units[-1].enabled = True # enable first unit in group + num_adapter_units.change(fn=display_units, inputs=[num_adapter_units], outputs=adapter_ui_units) + + with gr.Tab('Lite') as _tab_lite: + gr.HTML('Control LLLite') + with gr.Row(): + extra_controls = [ + ] + num_lite_units = gr.Slider(label="Units", minimum=1, maximum=max_units, step=1, value=1, scale=1) + lite_ui_units = [] # list of hidable accordions + for i in range(max_units): + with gr.Accordion(f'Control unit {i+1}', visible= i < num_lite_units.value) as unit_ui: + with gr.Row(): + with gr.Column(): + with gr.Row(): + enabled_cb = gr.Checkbox(value= i == 0, label="Enabled") + process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None') + model_id = gr.Dropdown(label="Model", choices=lite.list_models(), value='None') + ui_common.create_refresh_button(model_id, lite.list_models, lambda: {"choices": lite.list_models(refresh=True)}, 'refresh_lite_models') + model_strength = gr.Slider(label="Strength", minimum=0.01, maximum=1.0, step=0.01, value=1.0-i/10) + reset_btn = ui_components.ToolButton(value=ui_symbols.reset) + image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool']) + process_btn= ui_components.ToolButton(value=ui_symbols.preview) + lite_ui_units.append(unit_ui) + units.append(unit.Unit( + unit_type = 'lite', + result_txt = result_txt, + image_input = input_image, + enabled_cb = enabled_cb, + reset_btn = reset_btn, + process_id = process_id, + model_id = model_id, + model_strength = model_strength, + preview_process = preview_process, + preview_btn = process_btn, + image_upload = image_upload, + extra_controls = extra_controls, + ) + ) + if i == 0: + units[-1].enabled = True # enable first unit in group + num_lite_units.change(fn=display_units, inputs=[num_lite_units], outputs=lite_ui_units) + + with gr.Tab('Reference') as _tab_reference: + gr.HTML('ControlNet reference-only control') + with gr.Row(): + extra_controls = [ + gr.Radio(label="Reference context", choices=['Attention', 'Adain', 'Attention Adain'], value='Attention', interactive=True), + gr.Slider(label="Style fidelity", minimum=0.0, maximum=1.0, step=0.05, value=0.5, interactive=True), # prompt vs control importance + gr.Slider(label="Reference query weight", minimum=0.0, maximum=1.0, step=0.05, value=1.0, interactive=True), + gr.Slider(label="Reference adain weight", minimum=0.0, maximum=2.0, step=0.05, value=1.0, interactive=True), + ] + for i in range(1): # can only have one reference unit + with gr.Accordion(f'Reference unit {i+1}', visible=True) as unit_ui: + with gr.Row(): + with gr.Column(): + with gr.Row(): + enabled_cb = gr.Checkbox(value= i == 0, label="Enabled", visible=False) + model_id = gr.Dropdown(label="Reference", choices=reference.list_models(), value='Reference', visible=False) + model_strength = gr.Slider(label="Strength", minimum=0.01, maximum=1.0, step=0.01, value=1.0, visible=False) + reset_btn = ui_components.ToolButton(value=ui_symbols.reset) + image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool']) + process_btn= ui_components.ToolButton(value=ui_symbols.preview) + units.append(unit.Unit( + unit_type = 'reference', + result_txt = result_txt, + image_input = input_image, + enabled_cb = enabled_cb, + reset_btn = reset_btn, + process_id = process_id, + model_id = model_id, + model_strength = model_strength, + preview_process = preview_process, + preview_btn = process_btn, + 
image_upload = image_upload, + extra_controls = extra_controls, + ) + ) + if i == 0: + units[-1].enabled = True # enable first unit in group + + with gr.Tab('Processor settings') as _tab_settings: + with gr.Group(elem_classes=['processor-group']): + settings = [] + with gr.Accordion('HED', open=True, elem_classes=['processor-settings']): + settings.append(gr.Checkbox(label="Scribble", value=False)) + with gr.Accordion('Midas depth', open=True, elem_classes=['processor-settings']): + settings.append(gr.Slider(label="Background threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.1)) + settings.append(gr.Checkbox(label="Depth and normal", value=False)) + with gr.Accordion('MLSD', open=True, elem_classes=['processor-settings']): + settings.append(gr.Slider(label="Score threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.1)) + settings.append(gr.Slider(label="Distance threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.1)) + with gr.Accordion('OpenBody', open=True, elem_classes=['processor-settings']): + settings.append(gr.Checkbox(label="Body", value=True)) + settings.append(gr.Checkbox(label="Hands", value=False)) + settings.append(gr.Checkbox(label="Face", value=False)) + with gr.Accordion('PidiNet', open=True, elem_classes=['processor-settings']): + settings.append(gr.Checkbox(label="Scribble", value=False)) + settings.append(gr.Checkbox(label="Apply filter", value=False)) + with gr.Accordion('LineArt', open=True, elem_classes=['processor-settings']): + settings.append(gr.Checkbox(label="Coarse", value=False)) + with gr.Accordion('Leres Depth', open=True, elem_classes=['processor-settings']): + settings.append(gr.Checkbox(label="Boost", value=False)) + settings.append(gr.Slider(label="Near threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.0)) + settings.append(gr.Slider(label="Background threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.0)) + with gr.Accordion('MediaPipe Face', open=True, elem_classes=['processor-settings']): + settings.append(gr.Slider(label="Max faces", minimum=1, maximum=10, step=1, value=1)) + settings.append(gr.Slider(label="Min confidence", minimum=0.0, maximum=1.0, step=0.01, value=0.5)) + with gr.Accordion('Canny', open=True, elem_classes=['processor-settings']): + settings.append(gr.Slider(label="Low threshold", minimum=0, maximum=1000, step=1, value=100)) + settings.append(gr.Slider(label="High threshold", minimum=0, maximum=1000, step=1, value=200)) + with gr.Accordion('DWPose', open=True, elem_classes=['processor-settings']): + settings.append(gr.Radio(label="Model", choices=['Tiny', 'Medium', 'Large'], value='Tiny')) + settings.append(gr.Slider(label="Min confidence", minimum=0.0, maximum=1.0, step=0.01, value=0.3)) + with gr.Accordion('SegmentAnything', open=True, elem_classes=['processor-settings']): + settings.append(gr.Radio(label="Model", choices=['Base', 'Large'], value='Base')) + with gr.Accordion('Edge', open=True, elem_classes=['processor-settings']): + settings.append(gr.Checkbox(label="Parameter free", value=True)) + settings.append(gr.Radio(label="Mode", choices=['edge', 'gradient'], value='edge')) + with gr.Accordion('Zoe Depth', open=True, elem_classes=['processor-settings']): + settings.append(gr.Checkbox(label="Gamma corrected", value=False)) + for setting in settings: + setting.change(fn=processors.update_settings, inputs=settings, outputs=[]) + + tabs_state = gr.Text(value='none', visible=False) + input_fields = [ + input_type, + prompt, negative, styles, + steps, sampler_index, + seed, subseed, subseed_strength, 
seed_resize_from_h, seed_resize_from_w, + cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, full_quality, restore_faces, tiling, hdr_clamp, hdr_boundary, hdr_threshold, hdr_center, hdr_channel_shift, hdr_full_shift, hdr_maximize, hdr_max_center, hdr_max_boundry, + resize_mode, resize_name, width, height, scale_by, selected_scale_tab, resize_time, + denoising_strength, batch_count, batch_size, + video_skip_frames, video_type, video_duration, video_loop, video_pad, video_interpolate, + ip_adapter, ip_scale, ip_image, ip_type, + ] + output_fields = [ + preview_process, + output_image, + output_video, + output_gallery, + result_txt, + ] + paste_fields = [] # TODO paste fields + + control_dict = dict( + fn=generate_click, + _js="submit_control", + inputs=[tabs_state, tabs_state] + input_fields, + outputs=output_fields, + show_progress=False, + ) + prompt.submit(**control_dict) + btn_generate.click(**control_dict) + + btn_prompt_counter.click(fn=call_queue.wrap_queued_call(ui.update_token_counter), inputs=[prompt, steps], outputs=[prompt_counter]) + btn_negative_counter.click(fn=call_queue.wrap_queued_call(ui.update_token_counter), inputs=[negative, steps], outputs=[negative_counter]) + + generation_parameters_copypaste.add_paste_fields("control", input_image, paste_fields, override_settings) + bindings = generation_parameters_copypaste.ParamBinding(paste_button=btn_paste, tabname="control", source_text_component=prompt, source_image_component=output_gallery) + generation_parameters_copypaste.register_paste_params_button(bindings) + + if os.environ.get('SD_CONTROL_DEBUG', None) is not None: # debug only + from modules.control.test import test_processors, test_controlnets, test_adapters, test_xs, test_lite + gr.HTML('
Debug
') + with gr.Row(): + run_test_processors_btn = gr.Button(value="Test:Processors", variant='primary', elem_classes=['control-button']) + run_test_controlnets_btn = gr.Button(value="Test:ControlNets", variant='primary', elem_classes=['control-button']) + run_test_xs_btn = gr.Button(value="Test:ControlNets-XS", variant='primary', elem_classes=['control-button']) + run_test_adapters_btn = gr.Button(value="Test:Adapters", variant='primary', elem_classes=['control-button']) + run_test_lite_btn = gr.Button(value="Test:Control-LLLite", variant='primary', elem_classes=['control-button']) + + run_test_processors_btn.click(fn=test_processors, inputs=[input_image], outputs=[preview_process, output_image, output_video, output_gallery]) + run_test_controlnets_btn.click(fn=test_controlnets, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery]) + run_test_xs_btn.click(fn=test_xs, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery]) + run_test_adapters_btn.click(fn=test_adapters, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery]) + run_test_lite_btn.click(fn=test_lite, inputs=[prompt, negative, input_image], outputs=[preview_process, output_image, output_video, output_gallery]) + + return [(control_ui, 'Control', 'control')] diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6a3e1e6b1..8cf67bc67 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -250,13 +250,14 @@ def refresh_extensions_list(search_text, sort_column): global extensions_list # pylint: disable=global-statement import urllib.request try: - with urllib.request.urlopen(extensions_index) as response: + shared.log.debug(f'Updating extensions list: url={extensions_index}') + with urllib.request.urlopen(extensions_index, timeout=3.0) as response: text = response.read() extensions_list = json.loads(text) with open(os.path.join(paths.script_path, "html", "extensions.json"), "w", encoding="utf-8") as outfile: json_object = json.dumps(extensions_list, indent=2) outfile.write(json_object) - shared.log.debug(f'Updated extensions list: {len(extensions_list)} {extensions_index}') + shared.log.info(f'Updated extensions list: items={len(extensions_list)} url={extensions_index}') except Exception as e: shared.log.warning(f'Updated extensions list failed: {extensions_index} {e}') list_extensions() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f30dba2cc..e996940a2 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -24,7 +24,8 @@ dir_cache = {} # key=path, value=(mtime, listdir(path)) refresh_time = 0 extra_pages = shared.extra_networks -debug = shared.log.info if os.environ.get('SD_EN_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = shared.log.trace if os.environ.get('SD_EN_DEBUG', None) is not None else lambda *args, **kwargs: None +debug('Trace: EN') card_full = '''
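As an aside, the `refresh_extensions_list` hunk above adds a hard timeout to the remote index fetch so a slow or unreachable server cannot hang the UI. The following is a minimal standalone sketch of that pattern; the URL and function name here are placeholders, not the real SD.Next values:

```python
import json
import urllib.request

EXTENSIONS_INDEX = "https://example.com/extensions.json"  # placeholder, not the real index URL

def fetch_extensions_index(url: str = EXTENSIONS_INDEX, timeout: float = 3.0) -> list:
    # a short timeout bounds how long a refresh can block the UI; callers should catch
    # urllib.error.URLError / TimeoutError and fall back to the cached extensions.json on disk
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return json.loads(response.read())
```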
@@ -232,11 +233,10 @@ def create_page(self, tabname, skip = False): allowed_folders = [os.path.abspath(x) for x in self.allowed_directories_for_previews()] for parentdir, dirs in {d: modelloader.directory_list(d) for d in allowed_folders}.items(): for tgt in dirs.keys(): - if shared.backend == shared.Backend.DIFFUSERS: - if os.path.join(paths.models_path, 'Reference') in tgt: - subdirs['Reference'] = 1 - if shared.opts.diffusers_dir in tgt: - subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1 + if os.path.join(paths.models_path, 'Reference') in tgt: + subdirs['Reference'] = 1 + if shared.backend == shared.Backend.DIFFUSERS and shared.opts.diffusers_dir in tgt: + subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1 if 'models--' in tgt: continue subdir = tgt[len(parentdir):].replace("\\", "/") @@ -247,7 +247,7 @@ def create_page(self, tabname, skip = False): subdirs[subdir] = 1 debug(f"Extra networks: page='{self.name}' subfolders={list(subdirs)}") subdirs = OrderedDict(sorted(subdirs.items())) - if shared.backend == shared.Backend.DIFFUSERS and self.name == 'model': + if self.name == 'model': subdirs['Reference'] = 1 subdirs[os.path.basename(shared.opts.diffusers_dir)] = 1 subdirs.move_to_end(os.path.basename(shared.opts.diffusers_dir)) @@ -343,7 +343,6 @@ def handle_endtag(self, tag): self.text += '\n' fn = os.path.splitext(path)[0] + '.txt' - # if os.path.exists(fn): if fn in listdir(os.path.dirname(path)): try: with open(fn, "r", encoding="utf-8", errors="replace") as f: @@ -364,7 +363,6 @@ def handle_endtag(self, tag): def find_info(self, path): t0 = time.time() fn = os.path.splitext(path)[0] + '.json' - # if os.path.exists(fn): data = {} if fn in listdir(os.path.dirname(path)): data = shared.readfile(fn, silent=True) @@ -382,12 +380,15 @@ def initialize(): def register_page(page: ExtraNetworksPage): # registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions debug(f'EN register-page: {page}') + if page in shared.extra_networks: + debug(f'EN register-page: {page} already registered') + return shared.extra_networks.append(page) - allowed_dirs.clear() - for pg in shared.extra_networks: - for folder in pg.allowed_directories_for_previews(): - if folder not in allowed_dirs: - allowed_dirs.append(os.path.abspath(folder)) + # allowed_dirs.clear() + # for pg in shared.extra_networks: + for folder in page.allowed_directories_for_previews(): + if folder not in allowed_dirs: + allowed_dirs.append(os.path.abspath(folder)) def register_pages(): @@ -396,6 +397,7 @@ def register_pages(): from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints from modules.ui_extra_networks_styles import ExtraNetworksPageStyles from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs + debug('EN register-pages') register_page(ExtraNetworksPageCheckpoints()) register_page(ExtraNetworksPageStyles()) register_page(ExtraNetworksPageTextualInversion()) @@ -457,11 +459,11 @@ def create_ui(container, button_parent, tabname, skip_indexing = False): ui = ExtraNetworksUi() ui.tabname = tabname ui.pages = [] - ui.state = gr.Textbox('{}', elem_id=tabname+"_extra_state", visible=False) + ui.state = gr.Textbox('{}', elem_id=f"{tabname}_extra_state", visible=False) ui.visible = gr.State(value=False) # pylint: disable=abstract-class-instantiated - ui.details = gr.Group(elem_id=tabname+"_extra_details", visible=False) - ui.tabs = gr.Tabs(elem_id=tabname+"_extra_tabs") - ui.button_details = gr.Button('Details', elem_id=tabname+"_extra_details_btn", 
visible=False) + ui.details = gr.Group(elem_id=f"{tabname}_extra_details", visible=False) + ui.tabs = gr.Tabs(elem_id=f"{tabname}_extra_tabs") + ui.button_details = gr.Button('Details', elem_id=f"{tabname}_extra_details_btn", visible=False) state = {} if shared.cmd_opts.profile: import cProfile @@ -502,14 +504,14 @@ def toggle_visibility(is_visible): return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) with ui.details: - details_close = ToolButton(symbols.close, elem_id=tabname+"_extra_details_close", elem_classes=['extra-details-close']) + details_close = ToolButton(symbols.close, elem_id=f"{tabname}_extra_details_close", elem_classes=['extra-details-close']) details_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[ui.details]) with gr.Row(): with gr.Column(scale=1): text = gr.HTML('
title
') ui.details_components.append(text) with gr.Column(scale=1): - img = gr.Image(value=None, show_label=False, interactive=False, container=False, show_download_button=False, show_info=False, elem_id=tabname+"_extra_details_img", elem_classes=['extra-details-img']) + img = gr.Image(value=None, show_label=False, interactive=False, container=False, show_download_button=False, show_info=False, elem_id=f"{tabname}_extra_details_img", elem_classes=['extra-details-img']) ui.details_components.append(img) with gr.Row(): btn_save_img = gr.Button('Replace', elem_classes=['small-button']) @@ -542,29 +544,30 @@ def ui_tab_change(page): model_visible = page in ['Model'] return [gr.update(visible=scan_visible), gr.update(visible=save_visible), gr.update(visible=model_visible)] - ui.button_refresh = ToolButton(symbols.refresh, elem_id=tabname+"_extra_refresh") - ui.button_scan = ToolButton(symbols.scan, elem_id=tabname+"_extra_scan", visible=True) - ui.button_quicksave = ToolButton(symbols.book, elem_id=tabname+"_extra_quicksave", visible=False) - ui.button_save = ToolButton(symbols.book, elem_id=tabname+"_extra_save", visible=False) - ui.button_sort = ToolButton(symbols.sort, elem_id=tabname+"_extra_sort", visible=True) - ui.button_view = ToolButton(symbols.view, elem_id=tabname+"_extra_view", visible=True) - ui.button_close = ToolButton(symbols.close, elem_id=tabname+"_extra_close", visible=True) - ui.button_model = ToolButton(symbols.refine, elem_id=tabname+"_extra_model", visible=True) - ui.search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", elem_classes="textbox", lines=2, container=False) - ui.description = gr.Textbox('', show_label=False, elem_id=tabname+"_description", elem_classes="textbox", lines=2, interactive=False, container=False) + ui.button_refresh = ToolButton(symbols.refresh, elem_id=f"{tabname}_extra_refresh") + ui.button_scan = ToolButton(symbols.scan, elem_id=f"{tabname}_extra_scan", visible=True) + ui.button_quicksave = ToolButton(symbols.book, elem_id=f"{tabname}_extra_quicksave", visible=False) + ui.button_save = ToolButton(symbols.book, elem_id=f"{tabname}_extra_save", visible=False) + ui.button_sort = ToolButton(symbols.sort, elem_id=f"{tabname}_extra_sort", visible=True) + ui.button_view = ToolButton(symbols.view, elem_id=f"{tabname}_extra_view", visible=True) + ui.button_close = ToolButton(symbols.close, elem_id=f"{tabname}_extra_close", visible=True) + ui.button_model = ToolButton(symbols.refine, elem_id=f"{tabname}_extra_model", visible=True) + ui.search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", placeholder="Search...", elem_classes="textbox", lines=2, container=False) + ui.description = gr.Textbox('', show_label=False, elem_id=f"{tabname}_description", elem_classes="textbox", lines=2, interactive=False, container=False) if ui.tabname == 'txt2img': # refresh only once global refresh_time # pylint: disable=global-statement refresh_time = time.time() - threads = [] - for page in get_pages(): - if os.environ.get('SD_EN_DEBUG', None) is not None: - threads.append(threading.Thread(target=page.create_items, args=[ui.tabname])) - threads[-1].start() - else: - page.create_items(ui.tabname) - for thread in threads: - thread.join() + if not skip_indexing: + threads = [] + for page in get_pages(): + if os.environ.get('SD_EN_DEBUG', None) is not None: + threads.append(threading.Thread(target=page.create_items, args=[ui.tabname])) + threads[-1].start() + else: + page.create_items(ui.tabname) + for thread in 
threads: + thread.join() for page in get_pages(): page.create_page(ui.tabname, skip_indexing) with gr.Tab(page.title, id=page.title.lower().replace(" ", "_"), elem_classes="extra-networks-tab") as tab: @@ -574,7 +577,6 @@ def ui_tab_change(page): if shared.cmd_opts.profile: errors.profile(pr, 'ExtraNetworks') pr.disable() - # ui.tabs.change(fn=ui_tab_change, inputs=[], outputs=[ui.button_scan, ui.button_save]) def fn_save_img(image): @@ -795,21 +797,25 @@ def ui_quicksave_click(name): prompt = '' params = generation_parameters_copypaste.parse_generation_parameters(prompt) fn = os.path.join(shared.opts.styles_dir, os.path.splitext(name)[0] + '.json') + prompt = params.get('Prompt', '') item = { - "type": 'Style', "name": name, - "title": name, - "filename": fn, - "search_term": None, - "preview": None, "description": '', - "prompt": params.get('Prompt', ''), + "prompt": prompt, "negative": params.get('Negative prompt', ''), "extra": '', - "local_preview": None, + # "type": 'Style', + # "title": name, + # "filename": fn, + # "search_term": None, + # "preview": None, + # "local_preview": None, } shared.writefile(item, fn, silent=True) - shared.log.debug(f"Extra network quick save style: item={item['name']} filename='{fn}'") + if len(prompt) > 0: + shared.log.debug(f"Extra network quick save style: item={name} filename='{fn}'") + else: + shared.log.warning(f"Extra network quick save model: item={name} filename='{fn}' prompt is empty") def ui_sort_cards(msg): shared.log.debug(f'Extra networks: {msg}') diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 692acdc18..c217348c0 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -15,21 +15,25 @@ def refresh(self): shared.refresh_checkpoints() def list_reference(self): # pylint: disable=inconsistent-return-statements - if shared.backend != shared.Backend.DIFFUSERS: - return [] reference_models = shared.readfile(os.path.join('html', 'reference.json')) for k, v in reference_models.items(): + if shared.backend != shared.Backend.DIFFUSERS: + if not v.get('original', False): + continue + url = v.get('alt', None) or v['path'] + else: + url = v['path'] name = os.path.join(reference_dir, k) preview = v.get('preview', v['path']) yield { "type": 'Model', "name": name, "title": name, - "filename": v['path'], + "filename": url, "search_term": self.search_terms_from_path(name), "preview": self.find_preview(os.path.join(reference_dir, preview)), "local_preview": self.find_preview_file(os.path.join(reference_dir, preview)), - "onclick": '"' + html.escape(f"""return selectReference({json.dumps(v['path'])})""") + '"', + "onclick": '"' + html.escape(f"""return selectReference({json.dumps(url)})""") + '"', "hash": None, "mtime": 0, "size": 0, @@ -74,4 +78,7 @@ def list_items(self): yield record def allowed_directories_for_previews(self): - return [v for v in [shared.opts.ckpt_dir, shared.opts.diffusers_dir, reference_dir, sd_models.model_path] if v is not None] + if shared.backend == shared.Backend.DIFFUSERS: + return [v for v in [shared.opts.ckpt_dir, shared.opts.diffusers_dir, reference_dir] if v is not None] + else: + return [v for v in [shared.opts.ckpt_dir, reference_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_styles.py b/modules/ui_extra_networks_styles.py index 8dcccf96e..9f46850e9 100644 --- a/modules/ui_extra_networks_styles.py +++ b/modules/ui_extra_networks_styles.py @@ -53,7 +53,7 @@ def create_style(self, params): 
"name": name, "title": name, "filename": fn, - "search_term": f'{self.search_terms_from_path(name)}', + "search_term": f'{self.search_terms_from_path(fn)} {params.get("Prompt", "")}', "preview": self.find_preview(name), "description": '', "prompt": params.get('Prompt', ''), @@ -79,7 +79,7 @@ def create_item(self, k): "name": name, "title": k, "filename": style.filename, - "search_term": f'{txt} {self.search_terms_from_path(name)}', + "search_term": f'{self.search_terms_from_path(name)} {txt}', "preview": style.preview if getattr(style, 'preview', None) is not None and style.preview.startswith('data:') else self.find_preview(fn), "description": style.description if getattr(style, 'description', None) is not None and len(style.description) > 0 else txt, "prompt": getattr(style, 'prompt', ''), diff --git a/modules/ui_interrogate.py b/modules/ui_interrogate.py index b9bb074ea..98fbe9a2f 100644 --- a/modules/ui_interrogate.py +++ b/modules/ui_interrogate.py @@ -195,7 +195,7 @@ def create_ui(): analyze_btn = gr.Button("Analyze", variant='primary') unload_btn = gr.Button("Unload") with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "extras"]) + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "extras", "control"]) for tabname, button in buttons.items(): parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=button, tabname=tabname, source_text_component=prompt, source_image_component=image,)) with gr.Tab("Batch"): diff --git a/modules/ui_models.py b/modules/ui_models.py index ae4baab15..ab9870c05 100644 --- a/modules/ui_models.py +++ b/modules/ui_models.py @@ -488,12 +488,12 @@ def civit_select3(evt: gr.SelectData, in_data): log.debug(f'CivitAI select: variant={in_data[evt.index[0]]}') return in_data[evt.index[0]][3], in_data[evt.index[0]][0], gr.update(interactive=True) - def civit_download_model(model_url: str, model_name: str, model_path: str, model_type: str, image_url: str): + def civit_download_model(model_url: str, model_name: str, model_path: str, model_type: str, image_url: str, token: str = None): if model_url is None or len(model_url) == 0: return 'No model selected' try: from modules.modelloader import download_civit_model - res = download_civit_model(model_url, model_name, model_path, model_type, image_url) + res = download_civit_model(model_url, model_name, model_path, model_type, image_url, token) except Exception as e: res = f"CivitAI model downloaded error: model={model_url} {e}" log.error(res) @@ -515,7 +515,8 @@ def civit_search_metadata(civit_previews_rehash, title): continue for item in page.list_items(): meta = os.path.splitext(item['filename'])[0] + '.json' - if ('card-no-preview.png' in item['preview'] or not os.path.isfile(meta)) and os.path.isfile(item['filename']): + has_meta = os.path.isfile(meta) and os.stat(meta).st_size > 0 + if ('card-no-preview.png' in item['preview'] or not has_meta) and os.path.isfile(item['filename']): sha = item.get('hash', None) found = False if sha is not None and len(sha) > 0: @@ -577,6 +578,8 @@ def civit_search_metadata(civit_previews_rehash, title): with gr.Row(): civit_download_model_btn = gr.Button(value="Download", variant='primary') gr.HTML('Select a model, model version and and model variant from the search results to download or enter model URL manually
') + with gr.Row(): + civit_token = gr.Textbox('', label='CivitAI token', placeholder='optional access token for private or gated models') with gr.Row(): civit_name = gr.Textbox('', label='Model name', placeholder='select model from search results', visible=True) civit_selected = gr.Textbox('', label='Model URL', placeholder='select model from search results', visible=True) @@ -619,7 +622,7 @@ def is_visible(component): civit_results1.change(fn=is_visible, inputs=[civit_results1], outputs=[civit_results1]) civit_results2.change(fn=is_visible, inputs=[civit_results2], outputs=[civit_results2]) civit_results3.change(fn=is_visible, inputs=[civit_results3], outputs=[civit_results3]) - civit_download_model_btn.click(fn=civit_download_model, inputs=[civit_selected, civit_name, civit_path, civit_model_type, models_image], outputs=[models_outcome]) + civit_download_model_btn.click(fn=civit_download_model, inputs=[civit_selected, civit_name, civit_path, civit_model_type, models_image, civit_token], outputs=[models_outcome]) civit_previews_btn.click(fn=civit_search_metadata, inputs=[civit_previews_rehash, civit_previews_rehash], outputs=[models_outcome]) with gr.Tab(label="Update"): diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 334afcc9a..0fb103859 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -12,8 +12,8 @@ def wrap_pnginfo(image): return infotext_to_html(geninfo), info, geninfo -def submit_click(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, *script_inputs): - result_images, geninfo, js_info = postprocessing.run_postprocessing(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, *script_inputs) +def submit_click(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, save_output, *script_inputs): + result_images, geninfo, js_info = postprocessing.run_postprocessing(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, *script_inputs, save_output=save_output) return result_images, geninfo, json.dumps(js_info), '' @@ -31,7 +31,9 @@ def create_ui(): extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint"]) + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "control"]) + with gr.Row(): + save_output = gr.Checkbox(label='Save output', value=True, elem_id="extras_save_output") script_inputs = scripts.scripts_postproc.setup_ui() with gr.Column(): id_part = 'extras' @@ -66,6 +68,7 @@ def create_ui(): extras_batch_input_dir, extras_batch_output_dir, show_extras_results, + save_output, *script_inputs, ], outputs=[ diff --git a/modules/ui_symbols.py b/modules/ui_symbols.py index 0e80508ec..509df747b 100644 --- a/modules/ui_symbols.py +++ b/modules/ui_symbols.py @@ -18,6 +18,9 @@ random = '🎲️' reuse = '♻️' info = 'ℹ' # noqa +reset = '🔄' +upload = '⬆️' +preview = '🔍' """ refresh = '🔄' close = '🛗' diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 25f59125e..21730cf95 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -8,7 +8,7 @@ Savedfile = 
namedtuple("Savedfile", ["name"]) -debug = errors.log.info if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None def register_tmp_file(gradio, filename): @@ -68,7 +68,8 @@ def pil_to_temp_file(self, img: Image, dir: str, format="png") -> str: # pylint: with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) as tmp: name = tmp.name img.save(name, pnginfo=(metadata if use_metadata else None)) - shared.log.debug(f'Saving temp: image="{name}"') + size = os.path.getsize(name) + shared.log.debug(f'Saving temp: image="{name}" resolution={img.width}x{img.height} size={size}') params = ', '.join([f'{k}: {v}' for k, v in img.info.items()]) params = params[12:] if params.startswith('parameters: ') else params with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: diff --git a/pyproject.toml b/pyproject.toml index 931e8ba5e..1cc12d98a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.ruff] -target-version = "py310" +target-version = "py39" select = [ "F", "E", @@ -39,6 +39,9 @@ exclude = [ "repositories/taming", "repositories/blip", "repositories/codeformer", + "modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet", + "modules/control/units/*_model.py", + "modules/control/units/*_pipe.py", ] ignore = [ "A003", # Class attirbute shadowing builtin @@ -47,17 +50,19 @@ ignore = [ "E731", # Do not assign a `lambda` expression, use a `def` "I001", # Import block is un-sorted or un-formatted "W605", # Invalid escape sequence, messes with some docstrings + "B028", # No explicit stacklevel "B905", # Without explicit scrict "C408", # Rewrite as a literal "E402", # Module level import not at top of file "E721", # Do not compare types, use `isinstance()` - "F401", # Imported but unused "EXE001", # Shebang present + "F401", # Imported but unused "ISC003", # Implicit string concatenation "RUF005", # Consider concatenation "RUF012", # Mutable class attributes "RUF013", # Implict optional "RUF015", # Prefer `next` + "TID252", # Relative imports from parent modules ] [tool.ruff.flake8-bugbear] diff --git a/requirements.txt b/requirements.txt index bf54222d3..ab5c24b24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,6 @@ lmdb lpips omegaconf open-clip-torch -opencv-contrib-python-headless piexif psutil pyyaml @@ -42,6 +41,7 @@ scikit-image basicsr fasteners dctorch +pymatting httpx==0.24.1 compel==2.0.2 torchsde==0.2.6 @@ -49,12 +49,12 @@ clip-interrogator==0.6.0 antlr4-python3-runtime==4.9.3 requests==2.31.0 tqdm==4.66.1 -accelerate==0.24.1 -opencv-python-headless==4.7.0.72 -diffusers==0.24.0 +accelerate==0.25.0 +opencv-contrib-python-headless==4.8.1.78 +diffusers==0.25.0 einops==0.4.1 gradio==3.43.2 -huggingface_hub==0.19.4 +huggingface_hub==0.20.1 numexpr==2.8.4 numpy==1.24.4 numba==0.57.1 @@ -62,11 +62,11 @@ pandas==1.5.3 protobuf==3.20.3 pytorch_lightning==1.9.4 tokenizers==0.15.0 -transformers==4.35.2 +transformers==4.36.2 tomesd==0.1.3 -urllib3==1.26.15 +urllib3==1.26.18 Pillow==10.1.0 -timm==0.9.7 +timm==0.9.12 pydantic==1.10.13 -typing-extensions==4.8.0 +typing-extensions==4.9.0 peft diff --git a/scripts/animatediff.py b/scripts/animatediff.py index 4b0ef5d9c..074d31d6b 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -10,6 +10,7 @@ - AnimateFace: https://huggingface.co/nlper2022/animatediff_face_512/tree/main """ +import os import gradio as gr 
import diffusers from modules import scripts, processing, shared, devices, sd_models @@ -18,12 +19,15 @@ # config ADAPTERS = { 'None': None, - 'Motion 1.4': 'guoyww/animatediff-motion-adapter-v1-4', - 'Motion 1.5 v1': 'guoyww/animatediff-motion-adapter-v1-5', + 'Motion 1.5 v3' :'vladmandic/animatediff-v3', 'Motion 1.5 v2' :'guoyww/animatediff-motion-adapter-v1-5-2', - # 'Motion SD-XL Beta v1' :'vladmandic/animatediff-sdxl', + 'Motion 1.5 v1': 'guoyww/animatediff-motion-adapter-v1-5', + 'Motion 1.4': 'guoyww/animatediff-motion-adapter-v1-4', 'TemporalDiff': 'vladmandic/temporaldiff', 'AnimateFace': 'vladmandic/animateface', + # 'LongAnimateDiff 32': 'vladmandic/longanimatediff-32', + # 'LongAnimateDiff 64': 'vladmandic/longanimatediff-64', + # 'Motion SD-XL Beta v1' :'vladmandic/animatediff-sdxl', } LORAS = { 'None': None, @@ -70,6 +74,10 @@ def set_adapter(adapter_name: str = 'None'): # shared.sd_model.image_encoder = None shared.sd_model.unet.set_default_attn_processor() shared.sd_model.unet.config.encoder_hid_dim_type = None + if adapter_name.endswith('.ckpt') or adapter_name.endswith('.safetensors'): + import huggingface_hub as hf + folder, filename = os.path.split(adapter_name) + adapter_name = hf.hf_hub_download(repo_id=folder, filename=filename, cache_dir=shared.opts.diffusers_dir) try: shared.log.info(f'AnimateDiff load: adapter="{adapter_name}"') motion_adapter = None @@ -77,24 +85,21 @@ def set_adapter(adapter_name: str = 'None'): motion_adapter.to(shared.device) sd_models.set_diffuser_options(motion_adapter, vae=None, op='adapter') loaded_adapter = adapter_name - new_pipe = diffusers.AnimateDiffPipeline( vae=shared.sd_model.vae, text_encoder=shared.sd_model.text_encoder, tokenizer=shared.sd_model.tokenizer, unet=shared.sd_model.unet, scheduler=shared.sd_model.scheduler, + feature_extractor=getattr(shared.sd_model, 'feature_extractor', None), + image_encoder=getattr(shared.sd_model, 'image_encoder', None), motion_adapter=motion_adapter, ) orig_pipe = shared.sd_model - new_pipe.sd_checkpoint_info = shared.sd_model.sd_checkpoint_info - new_pipe.sd_model_hash = shared.sd_model.sd_model_hash - new_pipe.sd_model_checkpoint = shared.sd_model.sd_checkpoint_info.filename - new_pipe.is_sdxl = False - new_pipe.is_sd2 = False - new_pipe.is_sd1 = True shared.sd_model = new_pipe - shared.sd_model.to(shared.device) + if not ((shared.opts.diffusers_model_cpu_offload or shared.cmd_opts.medvram) or (shared.opts.diffusers_seq_cpu_offload or shared.cmd_opts.lowvram)): + shared.sd_model.to(shared.device) + sd_models.copy_diffuser_options(new_pipe, orig_pipe) sd_models.set_diffuser_options(shared.sd_model, vae=None, op='model') shared.log.debug(f'AnimateDiff create pipeline: adapter="{loaded_adapter}"') except Exception as e: @@ -123,12 +128,14 @@ def video_type_change(video_type): with gr.Accordion('AnimateDiff', open=False, elem_id='animatediff'): with gr.Row(): adapter_index = gr.Dropdown(label='Adapter', choices=list(ADAPTERS), value='None') - frames = gr.Slider(label='Frames', minimum=1, maximum=32, step=1, value=16) + frames = gr.Slider(label='Frames', minimum=1, maximum=64, step=1, value=16) + with gr.Row(): + override_scheduler = gr.Checkbox(label='Override sampler', value=True) with gr.Row(): lora_index = gr.Dropdown(label='Lora', choices=list(LORAS), value='None') strength = gr.Slider(label='Strength', minimum=0.0, maximum=2.0, step=0.05, value=1.0) with gr.Row(): - latent_mode = gr.Checkbox(label='Latent mode', value=False) + latent_mode = gr.Checkbox(label='Latent mode', value=True, 
visible=False) with gr.Row(): video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) @@ -137,27 +144,43 @@ def video_type_change(video_type): mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) - return [adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] + return [adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, override_scheduler] - def process(self, p: processing.StableDiffusionProcessing, adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument + def process(self, p: processing.StableDiffusionProcessing, adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, override_scheduler): # pylint: disable=arguments-differ, unused-argument adapter = ADAPTERS[adapter_index] lora = LORAS[lora_index] set_adapter(adapter) if motion_adapter is None: return - shared.log.debug(f'AnimateDiff: adapter="{adapter}" lora="{lora}" strength={strength} video={video_type}') + if override_scheduler: + p.sampler_name = 'Default' + shared.sd_model.scheduler = diffusers.DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + clip_sample=False, + num_train_timesteps=1000, + rescale_betas_zero_snr=False, + set_alpha_to_one=True, + steps_offset=0, + timestep_spacing="linspace", + trained_betas=None, + ) + shared.log.debug(f'AnimateDiff: adapter="{adapter}" lora="{lora}" strength={strength} video={video_type} scheduler={shared.sd_model.scheduler.__class__.__name__ if override_scheduler else p.sampler_name}') if lora is not None and lora != 'None': shared.sd_model.load_lora_weights(lora, adapter_name=lora) shared.sd_model.set_adapters([lora], adapter_weights=[strength]) p.extra_generation_params['AnimateDiff Lora'] = f'{lora}:{strength}' p.extra_generation_params['AnimateDiff'] = loaded_adapter p.do_not_save_grid = True + if 'animatediff' not in p.ops: + p.ops.append('animatediff') p.task_args['num_frames'] = frames p.task_args['num_inference_steps'] = p.steps if not latent_mode: p.task_args['output_type'] = 'np' - def postprocess(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument + def postprocess(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, adapter_index, frames, lora_index, strength, latent_mode, video_type, duration, gif_loop, mp4_pad, mp4_interpolate, override_scheduler): # pylint: disable=arguments-differ, unused-argument from modules.images import save_video if video_type != 'None': save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate) diff --git a/scripts/blipdiffusion.py b/scripts/blipdiffusion.py new file mode 100644 index 000000000..50ae01021 --- 
/dev/null +++ b/scripts/blipdiffusion.py @@ -0,0 +1,43 @@ +import gradio as gr +from modules import scripts, processing, shared, sd_models + + +title = 'BLIP Diffusion' + + +class Script(scripts.Script): + def title(self): + return title + + def show(self, is_img2img): + return is_img2img if shared.backend == shared.Backend.DIFFUSERS else False + + def ui(self, _is_img2img): + with gr.Row(): + source_subject = gr.Textbox(value='', label='Source subject') + with gr.Row(): + target_subject = gr.Textbox(value='', label='Target subject') + with gr.Row(): + prompt_strength = gr.Slider(label='Prompt strength', minimum=0.0, maximum=1.0, step=0.01, value=0.5) + return [source_subject, target_subject, prompt_strength] + + def run(self, p: processing.StableDiffusionProcessing, source_subject, target_subject, prompt_strength): # pylint: disable=arguments-differ, unused-argument + c = shared.sd_model.__class__.__name__ if shared.sd_model is not None else '' + if c != 'BlipDiffusionPipeline': + shared.log.error(f'{title}: model selected={c} required=BLIPDiffusion') + return None + if hasattr(p, 'init_images') and len(p.init_images) > 0: + p.task_args['reference_image'] = p.init_images[0] + p.task_args['prompt'] = [p.prompt] + p.task_args['neg_prompt'] = p.negative_prompt + p.task_args['prompt_strength'] = prompt_strength + p.task_args['source_subject_category'] = [source_subject] + p.task_args['target_subject_category'] = [target_subject] + p.task_args['output_type'] = 'pil' + shared.log.debug(f'BLIP Diffusion: args={p.task_args}') + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) + processed = processing.process_images(p) + return processed + else: + shared.log.error(f'{title}: no init_images') + return None diff --git a/scripts/demofusion.py b/scripts/demofusion.py new file mode 100644 index 000000000..d37552996 --- /dev/null +++ b/scripts/demofusion.py @@ -0,0 +1,1277 @@ +import random +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import torch +import torch.nn.functional as F +import gradio as gr +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import is_accelerate_available, is_accelerate_version +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from modules import scripts, processing, shared, sd_models + + +### Class definition +""" +Credits: https://github.com/PRIS-CV/DemoFusion +Source: https://github.com/PRIS-CV/DemoFusion/blob/main/pipeline_demofusion_sdxl.py +""" + + +def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size) + gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] + kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) + return kernel + + +def gaussian_filter(latents, 
kernel_size=3, sigma=1.0): + channels = latents.shape[1] + kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + return blurred_latents + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class DemoFusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin): + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + device = device or self._execution_device + + # set lora scale so that 
monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale # pylint: disable=attribute-defined-outside-init + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + shared.log.warning(f"The following part of your input was truncated because CLIP can only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}") + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt # pylint: disable=no-member + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + negative_prompt_embeds_list.append(negative_prompt_embeds) + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + num_images_per_prompt=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # DemoFusion specific checks + if max(height, width) % 1024 != 0: + shared.log.error(f'DemoFusion: resolution={width}x{height} long side must be divisible by 1024') + return None + + if num_images_per_prompt != 1: + shared.log.warning('DemoFusion: number of images per prompt is not supported and will be ignored') + num_images_per_prompt = 1 + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def get_views(self, height, width, window_size=128, stride=64, random_jitter=False): + # Here, we define the mappings F_i (see Eq. 
7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) + # if panorama's height/width < window_size, num_blocks of height/width should return 1 + height //= self.vae_scale_factor + width //= self.vae_scale_factor + num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1 + num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + + if h_end > height: + h_start = int(h_start + height - h_end) + h_end = int(height) + if w_end > width: + w_start = int(w_start + width - w_end) + w_end = int(width) + if h_start < 0: + h_end = int(h_end - h_start) + h_start = 0 + if w_start < 0: + w_end = int(w_end - w_start) + w_start = 0 + + if random_jitter: + jitter_range = (window_size - stride) // 4 + w_jitter = 0 + h_jitter = 0 + if (w_start != 0) and (w_end != width): + w_jitter = random.randint(-jitter_range, jitter_range) + elif (w_start == 0) and (w_end != width): + w_jitter = random.randint(-jitter_range, 0) + elif (w_start != 0) and (w_end == width): + w_jitter = random.randint(0, jitter_range) + if (h_start != 0) and (h_end != height): + h_jitter = random.randint(-jitter_range, jitter_range) + elif (h_start == 0) and (h_end != height): + h_jitter = random.randint(-jitter_range, 0) + elif (h_start != 0) and (h_end == height): + h_jitter = random.randint(0, jitter_range) + h_start += (h_jitter + jitter_range) + h_end += (h_jitter + jitter_range) + w_start += (w_jitter + jitter_range) + w_end += (w_jitter + jitter_range) + + views.append((h_start, h_end, w_start, w_end)) + return views + + def tiled_decode(self, latents, current_height, current_width): + core_size = self.unet.config.sample_size // 4 + core_stride = core_size + pad_size = self.unet.config.sample_size // 4 * 3 + decoder_view_batch_size = 1 + + if self.lowvram: + core_stride = core_size // 2 + pad_size = core_size + + views = self.get_views(current_height, current_width, stride=core_stride, window_size=core_size) + views_batch = [views[i : i + decoder_view_batch_size] for i in range(0, len(views), decoder_view_batch_size)] + latents_ = F.pad(latents, (pad_size, pad_size, pad_size, pad_size), 'constant', 0) + image = torch.zeros(latents.size(0), 3, current_height, current_width).to(latents.device) + count = torch.zeros_like(image).to(latents.device) + # get the latents corresponding to the current view coordinates + with self.progress_bar(total=len(views_batch)) as progress_bar: + for j, batch_view in enumerate(views_batch): + len(batch_view) + latents_for_view = torch.cat( + [ + latents_[:, :, h_start:h_end+pad_size*2, w_start:w_end+pad_size*2] + for h_start, h_end, w_start, w_end in batch_view + ] + ).to(self.vae.device) + image_patch = self.vae.decode(latents_for_view / self.vae.config.scaling_factor, return_dict=False)[0] + h_start, h_end, w_start, w_end = views[j] + h_start, h_end, w_start, w_end = h_start * self.vae_scale_factor, h_end * self.vae_scale_factor, w_start * self.vae_scale_factor, w_end * self.vae_scale_factor + p_h_start, p_h_end, p_w_start, p_w_end = pad_size * self.vae_scale_factor, image_patch.size(2) - pad_size * self.vae_scale_factor, pad_size * self.vae_scale_factor, image_patch.size(3) - pad_size * self.vae_scale_factor + image[:, :, h_start:h_end, 
w_start:w_end] += image_patch[:, :, p_h_start:p_h_end, p_w_start:p_w_end].to(latents.device) + count[:, :, h_start:h_end, w_start:w_end] += 1 + progress_bar.update() + image = image / count + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + ################### DemoFusion specific parameters #################### + view_batch_size: int = 16, + multi_decoder: bool = True, + stride: Optional[int] = 64, + cosine_scale_1: Optional[float] = 3., + cosine_scale_2: Optional[float] = 1., + cosine_scale_3: Optional[float] = 1., + sigma: Optional[float] = 1.0, + lowvram: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. 
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be the same + as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + ################### DemoFusion specific parameters #################### + view_batch_size (`int`, defaults to 16): + The batch size for multiple denoising paths. Typically, a larger batch size can result in higher + efficiency but comes with increased GPU memory requirements. + multi_decoder (`bool`, defaults to True): + Determines whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, + a tiled decoder becomes necessary. + stride (`int`, defaults to 64): + The stride of moving local patches. A smaller stride is better for alleviating seam issues, + but it also introduces additional computational overhead and inference time. + cosine_scale_1 (`float`, defaults to 3): + Controls the strength of skip-residual. For specific impacts, please refer to Appendix C + in the DemoFusion paper. + cosine_scale_2 (`float`, defaults to 1): + Controls the strength of dilated sampling. For specific impacts, please refer to Appendix C + in the DemoFusion paper. + cosine_scale_3 (`float`, defaults to 1): + Controls the strength of the gaussian filter. For specific impacts, please refer to Appendix C + in the DemoFusion paper. + sigma (`float`, defaults to 1): + The standard deviation of the gaussian filter. + lowvram (`bool`, defaults to False): + Try to fit in 8 GB of VRAM, with xformers installed. + + Examples: + + Returns: + a `list` with the generated images at each phase. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + x1_size = self.default_sample_size * self.vae_scale_factor + + height_scale = height / x1_size + width_scale = width / x1_size + scale_num = int(max(height_scale, width_scale)) + aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale) + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + num_images_per_prompt, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + self.lowvram = lowvram # pylint: disable=attribute-defined-outside-init + if self.lowvram: + self.vae.cpu() + self.unet.cpu() + self.text_encoder.to(device) + self.text_encoder_2.to(device) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height // scale_num, + width // scale_num, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + output_images = [] + + ############################################################### Phase 1 ################################################################# + + if self.lowvram: + self.text_encoder.cpu() + self.text_encoder_2.cpu() + + shared.log.debug('DemoFusion: phase=1 denoising') + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + if self.lowvram: + self.vae.cpu() + self.unet.to(device) + + latents_for_view = latents + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + latents.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + anchor_mean = latents.mean() + anchor_std = latents.std() + del latents_for_view, latent_model_input, noise_pred, noise_pred_text, noise_pred_uncond + if self.lowvram: + latents = latents.cpu() + torch.cuda.empty_cache() + if output_type != "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if self.lowvram: + needs_upcasting = False # use madebyollin/sdxl-vae-fp16-fix in lowvram mode! 
+ self.unet.cpu() + self.vae.to(device) + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + shared.log.debug('DemoFusion: phase=1 decoding') + if self.lowvram and multi_decoder: + current_width_height = self.unet.config.sample_size * self.vae_scale_factor + image = self.tiled_decode(latents, current_width_height, current_width_height) + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + image = self.image_processor.postprocess(image, output_type=output_type) + output_images.append(image[0]) + else: + output_images.append(latents) + + ####################################################### Phase 2+ ##################################################### + for current_scale_num in range(2, scale_num + 1): + if self.lowvram: + latents = latents.to(device) + self.unet.to(device) + torch.cuda.empty_cache() + shared.log.debug(f'DemoFusion: phase={current_scale_num} denoising') + current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num + current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num + if height > width: + current_width = int(current_width * aspect_ratio) + else: + current_height = int(current_height * aspect_ratio) + + latents = F.interpolate(latents.to(device), size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)), mode='bicubic') + + noise_latents = [] + noise = torch.randn_like(latents) + for timestep in timesteps: + noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0)) + noise_latents.append(noise_latent) + latents = noise_latents[0] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + count = torch.zeros_like(latents) + value = torch.zeros_like(latents) + cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu() + + c1 = cosine_factor ** cosine_scale_1 + latents = latents * (1 - c1) + noise_latents[i] * c1 + + ############################################# MultiDiffusion ############################################# + + views = self.get_views(current_height, current_width, stride=stride, window_size=self.unet.config.sample_size, random_jitter=True) + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + + jitter_range = (self.unet.config.sample_size - stride) // 4 + latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', 0) + + count_local = torch.zeros_like(latents_) + value_local = torch.zeros_like(latents_) + + for _j, batch_view in enumerate(views_batch): + vb_size = len(batch_view) + + # get the latents corresponding to the current view coordinates + latents_for_view = torch.cat( + [ + latents_[:, :, h_start:h_end, w_start:w_end] + for h_start, h_end, w_start, w_end in batch_view + ] + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents_for_view + latent_model_input = ( + latent_model_input.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latent_model_input + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + add_text_embeds_input = torch.cat([add_text_embeds] * 
vb_size) + add_time_ids_input = [] + for h_start, _h_end, w_start, _w_end in batch_view: + add_time_ids_ = add_time_ids.clone() + add_time_ids_[:, 2] = h_start * self.vae_scale_factor + add_time_ids_[:, 3] = w_start * self.vae_scale_factor + add_time_ids_input.append(add_time_ids_) + add_time_ids_input = torch.cat(add_time_ids_input) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds_input, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + self.scheduler._init_step_index(t) + latents_denoised_batch = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0] + + # extract value from batch + for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count_local[:, :, h_start:h_end, w_start:w_end] += 1 + + value_local = value_local[: ,:, jitter_range: jitter_range + current_height // self.vae_scale_factor, jitter_range: jitter_range + current_width // self.vae_scale_factor] + count_local = count_local[: ,:, jitter_range: jitter_range + current_height // self.vae_scale_factor, jitter_range: jitter_range + current_width // self.vae_scale_factor] + + c2 = cosine_factor ** cosine_scale_2 + + value += value_local / count_local * (1 - c2) + count += torch.ones_like(value_local) * (1 - c2) + + ############################################# Dilated Sampling ############################################# + + views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)] + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + + h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num + w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num + latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), 'constant', 0) + + count_global = torch.zeros_like(latents_) + value_global = torch.zeros_like(latents_) + + c3 = 0.99 * cosine_factor ** cosine_scale_3 + 1e-2 + std_, mean_ = latents_.std(), latents_.mean() + latents_gaussian = gaussian_filter(latents_, kernel_size=(2*current_scale_num-1), sigma=sigma*c3) + latents_gaussian = (latents_gaussian - latents_gaussian.mean()) / latents_gaussian.std() * std_ + mean_ + + for _j, batch_view in enumerate(views_batch): + latents_for_view = torch.cat( + [ + latents_[:, :, h::current_scale_num, w::current_scale_num] + for h, w in batch_view + ] + ) + latents_for_view_gaussian = torch.cat( + [ + latents_gaussian[:, :, h::current_scale_num, w::current_scale_num] + for h, w in batch_view + ] + ) + + vb_size = latents_for_view.size(0) + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents_for_view_gaussian + latent_model_input = ( + 
latent_model_input.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latent_model_input + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + add_text_embeds_input = torch.cat([add_text_embeds] * vb_size) + add_time_ids_input = torch.cat([add_time_ids] * vb_size) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds_input, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + self.scheduler._init_step_index(t) + latents_denoised_batch = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0] + + # extract value from batch + for latents_view_denoised, (h, w) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised + count_global[:, :, h::current_scale_num, w::current_scale_num] += 1 + + c2 = cosine_factor ** cosine_scale_2 + + value_global = value_global[: ,:, h_pad:, w_pad:] + + value += value_global * c2 + count += torch.ones_like(value_global) * c2 + + ########################################################### + + latents = torch.where(count > 0, value / count, value) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + ######################################################################################################################################### + + latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean + if self.lowvram: + latents = latents.cpu() + torch.cuda.empty_cache() + if output_type != "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if self.lowvram: + needs_upcasting = False # use madebyollin/sdxl-vae-fp16-fix in lowvram mode! 
+ self.unet.cpu() + self.vae.to(device) + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + shared.log.debug(f'DemoFusion: phase={current_scale_num} decoding') + if multi_decoder: + image = self.tiled_decode(latents, current_height, current_width) + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + image = self.image_processor.postprocess(image, output_type=output_type) + output_images.append(image[0]) + else: + image = latents + output_images.append(image) + + # Offload all models + self.maybe_free_model_hooks() + output = ImagePipelineOutput(images=output_images) + return output + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): # pylint: disable=arguments-differ + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) # pylint: disable=protected-access + is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook) # pylint: disable=protected-access + shared.log.info("Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again.") + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + # Offload back. 
+ if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) + + +### Script definition + +class Script(scripts.Script): + def title(self): + return 'DemoFusion' + + def show(self, is_img2img): + return not is_img2img if shared.backend == shared.Backend.DIFFUSERS else False + + # return signature is array of gradio components + def ui(self, _is_img2img): + with gr.Row(): + cosine_scale_1 = gr.Slider(minimum=0, maximum=5, step=0.1, value=3, label="Cosine scale 1") + cosine_scale_2 = gr.Slider(minimum=0, maximum=5, step=0.1, value=1, label="Cosine scale 2") + cosine_scale_3 = gr.Slider(minimum=0, maximum=5, step=0.1, value=1, label="Cosine scale 3") + with gr.Row(): + view_batch_size = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Denoising batch size") + sigma = gr.Slider(minimum=0.1, maximum=1, step=0.1, value=0.8, label="Sigma") + stride = gr.Slider(minimum=8, maximum=96, step=8, value=64, label="Stride") + with gr.Row(): + multi_decoder = gr.Checkbox(label="Multi decoder", value=True) + return [cosine_scale_1, cosine_scale_2, cosine_scale_3, sigma, view_batch_size, stride, multi_decoder] + + def run(self, p: processing.StableDiffusionProcessing, cosine_scale_1, cosine_scale_2, cosine_scale_3, sigma, view_batch_size, stride, multi_decoder): # pylint: disable=arguments-differ + c = shared.sd_model.__class__.__name__ if shared.sd_model is not None else '' + if c != 'StableDiffusionXLPipeline': + shared.log.warning(f'DemoFusion: pipeline={c} required=StableDiffusionXLPipeline') + return None + p.task_args['cosine_scale_1'] = cosine_scale_1 + p.task_args['cosine_scale_2'] = cosine_scale_2 + p.task_args['cosine_scale_3'] = cosine_scale_3 + p.task_args['sigma'] = sigma + p.task_args['view_batch_size'] = view_batch_size + p.task_args['stride'] = stride + p.task_args['multi_decoder'] = multi_decoder + p.task_args['output_type'] = 'np' + p.task_args['low_vram'] = True + shared.log.debug(f'DemoFusion: {p.task_args}') + old_pipe = shared.sd_model + new_pipe = DemoFusionSDXLPipeline( + vae = shared.sd_model.vae, + text_encoder=shared.sd_model.text_encoder, + text_encoder_2=shared.sd_model.text_encoder_2, + tokenizer=shared.sd_model.tokenizer, + tokenizer_2=shared.sd_model.tokenizer_2, + unet=shared.sd_model.unet, + scheduler=shared.sd_model.scheduler, + force_zeros_for_empty_prompt=shared.opts.diffusers_force_zeros, + ) + shared.sd_model = new_pipe + if not ((shared.opts.diffusers_model_cpu_offload or shared.cmd_opts.medvram) or (shared.opts.diffusers_seq_cpu_offload or shared.cmd_opts.lowvram)): + shared.sd_model.to(shared.device) + sd_models.set_diffuser_options(shared.sd_model, vae=None, op='model') + shared.log.debug(f'DemoFusion create: pipeline={shared.sd_model.__class__.__name__}') + processed = processing.process_images(p) + shared.sd_model = old_pipe + return processed diff --git a/scripts/example.py b/scripts/example.py new file mode 100644 index 000000000..01eaef88b --- /dev/null +++ b/scripts/example.py @@ -0,0 +1,141 @@ +import gradio as gr +from diffusers.pipelines import StableDiffusionPipeline, StableDiffusionXLPipeline # pylint: disable=unused-import +from modules import shared, scripts, processing, sd_models + +""" +This is a simpler template for script for SD.Next that implements a custom pipeline +Items 
that can be added: + - Any pipeline already in diffusers + List of pipelines that can be directly used: + - Any pipeline for which diffusers definition exists and can be copied + List of pipelines with community definitions: + - Any custom pipeline that you create + + Author: + - Your details + + Credits: + - Link to original implementation and author + + Contributions: + - Submit a PR on SD.Next GitHub repo to be included in /scripts + - Before submitting a PR, make sure to test your script thoroughly and that it passes code quality checks + Lint rules are part of SD.Next CI/CD pipeline + > pip install ruff pylint + > ruff scripts/example.py + > pylint scripts/example.py + """ + + ## Config + + # script title + title = 'Example' + + # is script available in txt2img tab + txt2img = False + + # is script available in img2img tab + img2img = False + + # is pipeline ok to run in pure latent mode without implicit conversions + # recommended so entire ecosystem can be used as-is, but requires that latent is in a format that sdnext can understand + # some pipelines may not support this, in which case set to false and pipeline will implicitly do things like vae encode/decode on its own + latent = True + + # base pipeline class from which this pipeline is derived, most commonly 'StableDiffusionPipeline' or 'StableDiffusionXLPipeline' + pipeline_base = 'StableDiffusionPipeline' + + # class definition for this pipeline + # for built-in diffuser pipelines, simply import it from diffusers.pipelines above + # for example only, it's set to the same as the base pipeline + # for community pipelines, copy class definition from community source code + # in which case only class definition code and required imports need to be copied, not the entire source code + pipeline_class = StableDiffusionPipeline + + # pipeline args values are defined in ui method below, here we need to define their exact names + # they also have to be in the exact order as they are defined in ui + # note: variable names should be exactly as defined in pipeline_class.__call__ method + # if pipeline requires a param and it's not provided, it will result in a runtime error + # if you provide a param that is not defined by pipeline, sdnext will strip it + params = ['test1', 'test2', 'test3', 'test4'] + + + ### Script definition + + class Script(scripts.Script): + def title(self): + return title + + def show(self, is_img2img): + if shared.backend == shared.Backend.DIFFUSERS: + return img2img if is_img2img else txt2img + return False + + # Define UI for pipeline + def ui(self, _is_img2img): + ui_controls = [] + with gr.Row(): + ui_controls.append(gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Test1")) + ui_controls.append(gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Test2")) + with gr.Row(): + ui_controls.append(gr.Checkbox(label="Test3", value=True)) + with gr.Row(): + ui_controls.append(gr.Textbox(label="Test4", value="", placeholder="enter text here")) + with gr.Row(): + gr.HTML('') + return ui_controls + + # Run pipeline using values from ui, passed in the same order as defined in params + def run(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable=arguments-differ + c = shared.sd_model.__class__.__name__ if shared.sd_model is not None else '' + if c != pipeline_base: + shared.log.warning(f'{title}: pipeline={c} required={pipeline_base}') + return None + orig_pipeline = shared.sd_model # backup current pipeline so it can be restored later + # create new pipeline from components of the currently loaded model + shared.sd_model = pipeline_class( + # different pipelines may require different init params, a missing required param results in an error such as: + # TypeError: StableDiffusionPipeline.__init__() missing 2 required positional arguments: 'safety_checker' and 'feature_extractor' + vae = shared.sd_model.vae, + text_encoder=shared.sd_model.text_encoder, + tokenizer=shared.sd_model.tokenizer, + unet=shared.sd_model.unet, + scheduler=shared.sd_model.scheduler, + safety_checker=shared.sd_model.safety_checker, + feature_extractor=shared.sd_model.feature_extractor, + ) + sd_models.copy_diffuser_options(shared.sd_model, orig_pipeline) # copy options from original pipeline + sd_models.set_diffuser_options(shared.sd_model) # set all model 
options such as fp16, offload, etc. + if not ((shared.opts.diffusers_model_cpu_offload or shared.cmd_opts.medvram) or (shared.opts.diffusers_seq_cpu_offload or shared.cmd_opts.lowvram)): + shared.sd_model.to(shared.device) # move pipeline if needed, but don't touch it if it's under automatic management + + # if pipeline also needs a specific type, you can set it here, but not commonly needed + # shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) + + # prepare params + # all pipeline params go into p.task_args and are automatically handled by sdnext from there + for i in range(len(args)): + p.task_args[params[i]] = args[i] + + # you can also re-use existing params from `p` object if pipeline wants them, but under a different name + # for example, if pipeline expects 'image' param, but you want to use 'init_images' instead which is what img2img tab uses + # p.task_args['image'] = p.init_images[0] + + if not latent: + p.task_args['output_type'] = 'np' + shared.log.debug(f'{c}: args={p.task_args}') + + # if you need to run any preprocessing, this is the place to do it + + # run processing + processed: processing.Processed = processing.process_images(p) + + # if you need to run any postprocessing, this is the place to do it + # you don't need to handle saving, metadata, etc. - sdnext will do it for you + + # restore original pipeline + shared.sd_model = orig_pipeline + return processed diff --git a/scripts/faceid.py b/scripts/faceid.py new file mode 100644 index 000000000..9e609a527 --- /dev/null +++ b/scripts/faceid.py @@ -0,0 +1,116 @@ +import os +import cv2 +import torch +import numpy as np +import gradio as gr +import diffusers +import huggingface_hub as hf +from modules import scripts, processing, shared, devices + + +app = None +try: + import onnxruntime + from insightface.app import FaceAnalysis + from ip_adapter.ip_adapter_faceid import IPAdapterFaceID + ok = True +except Exception as e: + shared.log.error(f'FaceID: {e}') + ok = False + + +class Script(scripts.Script): + def title(self): + return 'FaceID' + + def show(self, is_img2img): + return ok if shared.backend == shared.Backend.DIFFUSERS else False + + # return signature is array of gradio components + def ui(self, _is_img2img): + with gr.Row(): + scale = gr.Slider(label='Scale', minimum=0.0, maximum=1.0, step=0.01, value=1.0) + with gr.Row(): + image = gr.Image(image_mode='RGB', label='Image', source='upload', type='pil', width=512) + return [scale, image] + + def run(self, p: processing.StableDiffusionProcessing, scale, image): # pylint: disable=arguments-differ, unused-argument + if not ok: + shared.log.error('FaceID: missing dependencies') + return None + if image is None: + shared.log.error('FaceID: no init_images') + return None + if shared.sd_model_type != 'sd': + shared.log.error('FaceID: base model not supported') + return None + + global app # pylint: disable=global-statement + if app is None: + shared.log.debug(f"ONNX: device={onnxruntime.get_device()} providers={onnxruntime.get_available_providers()}") + app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + onnxruntime.set_default_logger_severity(3) + app.prepare(ctx_id=0, det_thresh=0.5, det_size=(640, 640)) + + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + faces = app.get(image) + if len(faces) == 0: + shared.log.error('FaceID: no faces found') + return None + for face in faces: + shared.log.debug(f'FaceID face: score={face.det_score:.2f} gender={"female" if 
face.gender==0 else "male"} age={face.age} bbox={face.bbox}') + embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) + + ip_ckpt = "h94/IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin" + shared.log.debug(f'FaceID model load: {ip_ckpt}') + folder, filename = os.path.split(ip_ckpt) + basename, _ext = os.path.splitext(filename) + model_path = hf.hf_hub_download(repo_id=folder, filename=filename, cache_dir=shared.opts.diffusers_dir) + if model_path is None: + shared.log.error(f'FaceID: model download failed: {ip_ckpt}') + return None + + processing.process_init(p) + shared.sd_model.scheduler = diffusers.DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + ip_model = IPAdapterFaceID(shared.sd_model, model_path, devices.device) + ip_model_dict = { + 'prompt': p.all_prompts[0], + 'negative_prompt': p.all_negative_prompts[0], + 'num_samples': p.batch_size, + 'width': p.width, + 'height': p.height, + 'num_inference_steps': p.steps, + 'scale': scale, + 'guidance_scale': p.cfg_scale, + 'seed': int(p.all_seeds[0]), + 'faceid_embeds': None, + } + shared.log.debug(f'FaceID args: {ip_model_dict}') + ip_model_dict['faceid_embeds'] = embeds + images = ip_model.generate(**ip_model_dict) + + ip_model = None + p.extra_generation_params["IP Adapter"] = f'{basename}:{scale}' + for i, face in enumerate(faces): + p.extra_generation_params[f"FaceID {i} score"] = f'{face.det_score:.2f}' + p.extra_generation_params[f"FaceID {i} gender"] = "female" if face.gender==0 else "male" + p.extra_generation_params[f"FaceID {i} age"] = face.age + processed = processing.Processed( + p, + images_list=images, + seed=p.seed, + subseed=p.subseed, + index_of_first_image=0, + ) + processed.info = processed.infotext(p, 0) + processed.infotexts = [processed.info] + devices.torch_gc() + return processed diff --git a/scripts/ipadapter.py b/scripts/ipadapter.py index 797ff2785..ab6b89f08 100644 --- a/scripts/ipadapter.py +++ b/scripts/ipadapter.py @@ -8,25 +8,27 @@ - SD/SDXL autodetect """ +import time import gradio as gr from modules import scripts, processing, shared, devices image_encoder = None +image_encoder_type = None loaded = None -ADAPTERS = [ - 'none', - 'models/ip-adapter_sd15', - 'models/ip-adapter_sd15_light', +ADAPTERS = { + 'None': 'none', + 'Base': 'ip-adapter_sd15', + 'Light': 'ip-adapter_sd15_light', + 'Plus': 'ip-adapter-plus_sd15', + 'Plus Face': 'ip-adapter-plus-face_sd15', + 'Full face': 'ip-adapter-full-face_sd15', + 'Base SXDL': 'ip-adapter_sdxl', # 'models/ip-adapter_sd15_vit-G', # RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x1024 and 1280x3072) - # 'models/ip-adapter-plus_sd15', # KeyError: 'proj.weight' - # 'models/ip-adapter-plus-face_sd15', # KeyError: 'proj.weight' - # 'models/ip-adapter-full-face_sd15', # KeyError: 'proj.weight' - 'sdxl_models/ip-adapter_sdxl', # 'sdxl_models/ip-adapter_sdxl_vit-h', # 'sdxl_models/ip-adapter-plus_sdxl_vit-h', # 'sdxl_models/ip-adapter-plus-face_sdxl_vit-h', -] +} class Script(scripts.Script): @@ -39,18 +41,25 @@ def show(self, is_img2img): def ui(self, _is_img2img): with gr.Accordion('IP Adapter', open=False, elem_id='ipadapter'): with gr.Row(): - adapter = gr.Dropdown(label='Adapter', choices=ADAPTERS, value='none') + adapter = gr.Dropdown(label='Adapter', choices=list(ADAPTERS), value='none') scale = gr.Slider(label='Scale', minimum=0.0, maximum=1.0, step=0.01, value=0.5) with gr.Row(): image = 
gr.Image(image_mode='RGB', label='Image', source='upload', type='pil', width=512) return [adapter, scale, image] def process(self, p: processing.StableDiffusionProcessing, adapter, scale, image): # pylint: disable=arguments-differ - import torch - from transformers import CLIPVisionModelWithProjection - + # overrides + adapter = ADAPTERS.get(adapter, None) + if hasattr(p, 'ip_adapter_name'): + adapter = p.ip_adapter_name + if hasattr(p, 'ip_adapter_scale'): + scale = p.ip_adapter_scale + if hasattr(p, 'ip_adapter_image'): + image = p.ip_adapter_image + if adapter is None: + return # init code - global loaded, image_encoder # pylint: disable=global-statement + global loaded, image_encoder, image_encoder_type # pylint: disable=global-statement if shared.sd_model is None: return if shared.backend != shared.Backend.DIFFUSERS: @@ -79,26 +88,32 @@ def process(self, p: processing.StableDiffusionProcessing, adapter, scale, image else: shared.log.error(f'IP adapter: unsupported model type: {shared.sd_model_type}') return - if image_encoder is None: + if image_encoder is None or image_encoder_type != shared.sd_model_type: try: - image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder=subfolder, torch_dtype=torch.float16, cache_dir=shared.opts.diffusers_dir, use_safetensors=True).to(devices.device) + from transformers import CLIPVisionModelWithProjection + image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder=subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.diffusers_dir, use_safetensors=True).to(devices.device) + image_encoder_type = shared.sd_model_type except Exception as e: shared.log.error(f'IP adapter: failed to load image encoder: {e}') return # main code - subfolder, model = adapter.split('/') - if model != loaded or getattr(shared.sd_model.unet.config, 'encoder_hid_dim_type', None) is None: + subfolder = 'models' if 'sd15' in adapter else 'sdxl_models' + if adapter != loaded or getattr(shared.sd_model.unet.config, 'encoder_hid_dim_type', None) is None: + t0 = time.time() if loaded is not None: shared.log.debug('IP adapter: reset attention processor') shared.sd_model.unet.set_default_attn_processor() loaded = None - shared.log.info(f'IP adapter load: adapter="{model}" scale={scale} image={image}') + else: + shared.log.debug('IP adapter: load attention processor') shared.sd_model.image_encoder = image_encoder - shared.sd_model.load_ip_adapter("h94/IP-Adapter", subfolder=subfolder, weight_name=f'{model}.safetensors') - loaded = model + shared.sd_model.load_ip_adapter("h94/IP-Adapter", subfolder=subfolder, weight_name=f'{adapter}.safetensors') + t1 = time.time() + shared.log.info(f'IP adapter load: adapter="{adapter}" scale={scale} image={image} time={t1-t0:.2f}') + loaded = adapter else: - shared.log.debug(f'IP adapter cache: adapter="{model}" scale={scale} image={image}') + shared.log.debug(f'IP adapter cache: adapter="{adapter}" scale={scale} image={image}') shared.sd_model.set_ip_adapter_scale(scale) p.task_args['ip_adapter_image'] = p.batch_size * [image] p.extra_generation_params["IP Adapter"] = f'{adapter}:{scale}' diff --git a/scripts/postprocessing_video.py b/scripts/postprocessing_video.py new file mode 100644 index 000000000..f34036961 --- /dev/null +++ b/scripts/postprocessing_video.py @@ -0,0 +1,47 @@ +import gradio as gr +import modules.images +from modules import scripts_postprocessing + + +class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): + name = "Video" + + def 
ui(self): + def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + ] + + with gr.Row(): + video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') + duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) + with gr.Row(): + loop = gr.Checkbox(label='Loop', value=True, visible=False) + pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) + interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + scale = gr.Slider(label='Rescale', minimum=0.5, maximum=2, step=0.05, value=1, visible=False) + change = gr.Slider(label='Frame change sensitivity', minimum=0, maximum=1, step=0.05, value=0.3, visible=False) + with gr.Row(): + filename = gr.Textbox(label='Filename', placeholder='enter filename', lines=1) + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate, scale, change]) + return { + "filename": filename, + "video_type": video_type, + "duration": duration, + "loop": loop, + "pad": pad, + "interpolate": interpolate, + "scale": scale, + "change": change, + } + + def postprocess(self, images, filename, video_type, duration, loop, pad, interpolate, scale, change): # pylint: disable=arguments-differ + filename = filename.strip() + if video_type == 'None' or len(filename) == 0 or images is None or len(images) < 2: + return + modules.images.save_video(p=None, filename=filename, images=images, video_type=video_type, duration=duration, loop=loop, pad=pad, interpolate=interpolate, scale=scale, change=change) diff --git a/scripts/stablevideodiffusion.py b/scripts/stablevideodiffusion.py index 16077cc2c..3e20c4373 100644 --- a/scripts/stablevideodiffusion.py +++ b/scripts/stablevideodiffusion.py @@ -1,6 +1,7 @@ """ Additional params for StableVideoDiffusion """ + import torch import gradio as gr from modules import scripts, processing, shared, sd_models, images @@ -44,7 +45,9 @@ def video_type_change(video_type): return [num_frames, override_resolution, min_guidance_scale, max_guidance_scale, decode_chunk_size, motion_bucket_id, noise_aug_strength, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] def run(self, p: processing.StableDiffusionProcessing, num_frames, override_resolution, min_guidance_scale, max_guidance_scale, decode_chunk_size, motion_bucket_id, noise_aug_strength, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument - if shared.sd_model is None or shared.sd_model.__class__.__name__ != 'StableVideoDiffusionPipeline': + c = shared.sd_model.__class__.__name__ if shared.sd_model is not None else '' + if c != 'StableVideoDiffusionPipeline' and c != 'TextToVideoSDPipeline': + shared.log.error(f'StableVideo: model selected={c} required=StableVideoDiffusion') return None if hasattr(p, 'init_images') and len(p.init_images) > 0: if override_resolution: @@ -53,9 +56,13 @@ def run(self, p: processing.StableDiffusionProcessing, num_frames, override_reso p.task_args['image'] = images.resize_image(resize_mode=2, im=p.init_images[0], width=p.width, height=p.height, upscaler_name=None, output_type='pil') else: p.task_args['image'] = p.init_images[0] - 
p.ops.append('svd') + p.ops.append('stablevideo') p.do_not_save_grid = True - p.sampler_name = 'Default' # svd does not support non-default sampler + if c == 'StableVideoDiffusionPipeline': + p.sampler_name = 'Default' # svd does not support non-default sampler + p.task_args['output_type'] = 'np' + else: + p.task_args['output_type'] = 'pil' p.task_args['generator'] = torch.manual_seed(p.seed) # svd does not support gpu based generator p.task_args['width'] = p.width p.task_args['height'] = p.height @@ -66,7 +73,6 @@ def run(self, p: processing.StableDiffusionProcessing, num_frames, override_reso p.task_args['num_inference_steps'] = p.steps p.task_args['min_guidance_scale'] = min_guidance_scale p.task_args['max_guidance_scale'] = max_guidance_scale - p.task_args['output_type'] = 'np' shared.log.debug(f'StableVideo: args={p.task_args}') shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) processed = processing.process_images(p) diff --git a/scripts/text2video.py b/scripts/text2video.py new file mode 100644 index 000000000..34571b4ae --- /dev/null +++ b/scripts/text2video.py @@ -0,0 +1,106 @@ +""" +Additional params for Text-to-Video + + +TODO: +- Video-to-Video upscaling: , +""" + +import gradio as gr +from modules import scripts, processing, shared, images, sd_models, modelloader + + +MODELS = [ + {'name': 'None'}, + {'name': 'ModelScope v1.7b', 'path': 'damo-vilab/text-to-video-ms-1.7b', 'params': [16,320,320]}, + {'name': 'ZeroScope v1', 'path': 'cerspense/zeroscope_v1_320s', 'params': [16,320,320]}, + {'name': 'ZeroScope v1.1', 'path': 'cerspense/zeroscope_v1-1_320s', 'params': [16,320,320]}, + {'name': 'ZeroScope v2', 'path': 'cerspense/zeroscope_v2_576w', 'params': [24,576,320]}, + {'name': 'ZeroScope v2 Dark', 'path': 'cerspense/zeroscope_v2_dark_30x448x256', 'params': [24,448,256]}, + {'name': 'Potat v1', 'path': 'camenduru/potat1', 'params': [24,1024,576]}, +] + + +class Script(scripts.Script): + def title(self): + return 'Text-to-Video' + + def show(self, is_img2img): + return not is_img2img if shared.backend == shared.Backend.DIFFUSERS else False + + # return signature is array of gradio components + def ui(self, _is_img2img): + + def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + ] + + def model_info_change(model_name): + if model_name == 'None': + return gr.update(value='') + else: + model = next(m for m in MODELS if m['name'] == model_name) + return gr.update(value=f'   frames: {model["params"][0]} size: {model["params"][1]}x{model["params"][2]} link') + + with gr.Row(): + model_name = gr.Dropdown(label='Model', value='None', choices=[m['name'] for m in MODELS]) + with gr.Row(): + model_info = gr.HTML() + model_name.change(fn=model_info_change, inputs=[model_name], outputs=[model_info]) + with gr.Row(): + use_default = gr.Checkbox(label='Use defaults', value=True) + num_frames = gr.Slider(label='Frames', minimum=1, maximum=50, step=1, value=0) + with gr.Row(): + video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') + duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) + with gr.Row(): + gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) + mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) + 
mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + return [model_name, use_default, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] + + def run(self, p: processing.StableDiffusionProcessing, model_name, use_default, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument + if model_name == 'None': + return + model = [m for m in MODELS if m['name'] == model_name][0] + shared.log.debug(f'Text2Video: model={model} defaults={use_default} frames={num_frames}, video={video_type} duration={duration} loop={gif_loop} pad={mp4_pad} interpolate={mp4_interpolate}') + + if model['path'] in shared.opts.sd_model_checkpoint: + shared.log.debug(f'Text2Video cached: model={shared.opts.sd_model_checkpoint}') + else: + checkpoint = sd_models.get_closet_checkpoint_match(model['path']) + if checkpoint is None: + shared.log.debug(f'Text2Video downloading: model={model["path"]}') + checkpoint = modelloader.download_diffusers_model(hub_id=model['path']) + sd_models.list_models() + if checkpoint is None: + shared.log.error(f'Text2Video: failed to find model={model["path"]}') + return + shared.log.debug(f'Text2Video loading: model={checkpoint}') + shared.opts.sd_model_checkpoint = checkpoint + sd_models.reload_model_weights(op='model') + + p.ops.append('text2video') + p.do_not_save_grid = True + if use_default: + p.task_args['num_frames'] = model['params'][0] + p.width = model['params'][1] + p.height = model['params'][2] + elif num_frames > 0: + p.task_args['num_frames'] = num_frames + else: + shared.log.error('Text2Video: invalid number of frames') + return + + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) + shared.log.debug(f'Text2Video: args={p.task_args}') + processed = processing.process_images(p) + + if video_type != 'None': + images.save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate) + return processed diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 15d11e555..f6fc5d553 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -269,6 +269,8 @@ def __init__(self, *args, **kwargs): AxisOption("[FreeU] 2nd stage backbone factor", float, apply_setting('freeu_b2')), AxisOption("[FreeU] 1st stage skip factor", float, apply_setting('freeu_s1')), AxisOption("[FreeU] 2nd stage skip factor", float, apply_setting('freeu_s2')), + AxisOption("[IP adapter] Name", str, apply_field('ip_adapter_name'), cost=1.0), + AxisOption("[IP adapter] Scale", float, apply_field('ip_adapter_scale')), ] diff --git a/webui.py b/webui.py index d352bed72..76c1cb0a7 100644 --- a/webui.py +++ b/webui.py @@ -266,6 +266,7 @@ def start_ui(): favicon_path='html/logo.ico', allowed_paths=[os.path.dirname(__file__), cmd_opts.data_dir], app_kwargs=fastapi_args, + _frontend=not cmd_opts.share, ) if cmd_opts.data_dir is not None: ui_tempdir.register_tmp_file(shared.demo, os.path.join(cmd_opts.data_dir, 'x')) @@ -315,7 +316,8 @@ def webui(restart=False): for k, v in modules.script_callbacks.callback_map.items(): shared.log.debug(f'Registered callbacks: {k}={len(v)} {[c.script for c in v]}') log.info(f"Startup time: {timer.startup.summary()}") - debug = log.info if os.environ.get('SD_SCRIPT_DEBUG', None) is not None 
else lambda *args, **kwargs: None + debug = log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None + debug('Trace: SCRIPTS') debug('Loaded scripts:') for m in modules.scripts.scripts_data: debug(f' {m}') diff --git a/webui.sh b/webui.sh index 692cc36c1..d88c8ec53 100755 --- a/webui.sh +++ b/webui.sh @@ -80,22 +80,16 @@ else exit 1 fi -#Set OneAPI environmet if it's not set by the user -if ([[ "$@" == *"--use-ipex"* ]] || [[ -d "/opt/intel/oneapi" ]] || [[ ! -z "$ONEAPI_ROOT" ]]) && [ ! -x "$(command -v sycl-ls)" ] +if [ -d "$(realpath "$venv_dir")/lib/" ] then - echo "Setting OneAPI environment" - if [[ -z "$ONEAPI_ROOT" ]] - then - ONEAPI_ROOT=/opt/intel/oneapi - fi - source $ONEAPI_ROOT/setvars.sh + export LD_LIBRARY_PATH=$(realpath "$venv_dir")/lib/:$LD_LIBRARY_PATH fi if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ] then echo "Launching accelerate launch.py..." exec accelerate launch --num_cpu_threads_per_process=6 launch.py "$@" -elif [[ "$@" == *"--use-ipex"* ]] && [[ -z "${first_launch}" ]] && [ -x "$(command -v ipexrun)" ] && [ -x "$(command -v sycl-ls)" ] +elif [[ "$@" == *"--use-ipex"* ]] && [[ -z "${first_launch}" ]] && [ -x "$(command -v ipexrun)" ] && [[ -z "${DISABLE_IPEXRUN}" ]] then echo "Launching ipexrun launch.py..." exec ipexrun --multi-task-manager 'taskset' --memory-allocator 'jemalloc' launch.py "$@" diff --git a/wiki b/wiki index 931082304..9a9713ecb 160000 --- a/wiki +++ b/wiki @@ -1 +1 @@ -Subproject commit 931082304da70e7683d993e28a1765c8cb6844c1 +Subproject commit 9a9713ecbbebaeebe9b15f895d6cab566d3a79a7
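The reworked `scripts/ipadapter.py` above only applies overrides when the corresponding attributes exist on the processing object (`p.ip_adapter_name`, `p.ip_adapter_scale`, `p.ip_adapter_image`), which is also the mechanism the new `[IP adapter] Name` and `[IP adapter] Scale` xyz-grid axes use via `apply_field`. Below is a rough sketch of how another extension or script might set the same fields before the IP adapter script's `process()` runs; `reference.png` is a hypothetical path and the helper name is made up for illustration.

```python
# Sketch only: how another extension might override IP adapter settings on the
# processing object before scripts/ipadapter.py's process() hook runs.
from PIL import Image

def setup_ip_adapter(p):  # p is the processing.StableDiffusionProcessing object SD.Next passes in
    # ipadapter.py picks these up via hasattr() and uses them instead of the UI values;
    # the name is the raw weight filename from the h94/IP-Adapter repo, as listed in ADAPTERS
    p.ip_adapter_name = 'ip-adapter_sd15'
    p.ip_adapter_scale = 0.6
    p.ip_adapter_image = Image.open('reference.png')  # hypothetical reference image
```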
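Both the new `scripts/postprocessing_video.py` and `scripts/text2video.py` hand their frames to `modules.images.save_video`. A minimal sketch of calling it directly is shown below, assuming it runs inside SD.Next where `modules.images` is importable; the parameter names are taken from the calls in the diff above, the output path is hypothetical, and the frames are synthetic placeholders rather than real generation output.

```python
# Sketch only: assumes this runs inside SD.Next where modules.images is importable.
from PIL import Image
import modules.images

# synthetic placeholder frames; normally these come from processed.images of a video pipeline
frames = [Image.new('RGB', (512, 512), color=(v, v, v)) for v in (64, 128, 192)]

modules.images.save_video(
    p=None,                        # no processing object when saving standalone frames
    filename='outputs/clip.mp4',   # hypothetical output path
    images=frames,
    video_type='MP4',              # 'GIF', 'PNG' or 'MP4', matching the UI choices above
    duration=2.0,
    loop=True,
    pad=1,                         # MP4 only: padded frames
    interpolate=0,                 # MP4 only: interpolated frames
)
```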