support bnb quantization during load
Signed-off-by: Vladimir Mandic <mandic00@live.com>
vladmandic committed Oct 22, 2024
1 parent 63bedde commit f191134
Showing 5 changed files with 81 additions and 28 deletions.
36 changes: 30 additions & 6 deletions modules/model_flux.py
@@ -108,6 +108,35 @@ def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unu
return transformer, text_encoder_2


def quant_flux_bnb(checkpoint_info, transformer, text_encoder_2):
repo_id = sd_models.path_to_repo(checkpoint_info.name)
cache_dir=shared.opts.diffusers_dir
if len(shared.opts.bnb_quantization) > 0 and (transformer is None or text_encoder_2 is None):
from modules.model_quant import load_bnb
load_bnb('Load model: type=FLUX')
try:
bnb_config = diffusers.BitsAndBytesConfig(
load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
bnb_4bit_compute_dtype=devices.dtype
)
if 'Model' in shared.opts.bnb_quantization and transformer is None:
transformer = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
if 'Text Encoder' in shared.opts.bnb_quantization and text_encoder_2 is None:
text_encoder_2 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
except Exception as e:
shared.log.error(f"Load model: type=FLUX failed quantize using BnB: {e}")
transformer, text_encoder_2 = None, None
if debug:
from modules import errors
errors.display(e, 'FLUX:')
return transformer, text_encoder_2


def load_flux_gguf(file_path):
transformer = None
model_te.install_gguf()
@@ -257,6 +286,7 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch

# initialize pipeline with pre-loaded components
components = {}
transformer, text_encoder_2 = quant_flux_bnb(checkpoint_info, transformer, text_encoder_2)
if transformer is not None:
components['transformer'] = transformer
sd_unet.loaded_unet = shared.opts.sd_unet
@@ -276,10 +306,4 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
shared.log.warning(f'Load model: type=FLUX component={c} dtype={components[c].dtype} cast dtype={devices.dtype}')
components[c] = components[c].to(dtype=devices.dtype)
pipe = diffusers.FluxPipeline.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **components, **diffusers_load_config)
try:
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux"] = diffusers.FluxPipeline
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux"] = diffusers.FluxImg2ImgPipeline
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux"] = diffusers.FluxInpaintPipeline
except Exception:
pass
return pipe
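
For reference, a minimal standalone sketch of the loading pattern the new quant_flux_bnb helper wraps: the fp8 option maps to bitsandbytes 8-bit loading, and nf4/fp4 map to 4-bit, matching the config built above. The repo id, dtype, and variable names here are illustrative and not taken from the commit.

import torch
import diffusers

repo_id = 'black-forest-labs/FLUX.1-dev'  # illustrative checkpoint
quant_type = 'nf4'                        # 'nf4' or 'fp4' -> 4-bit, 'fp8' -> 8-bit

bnb_config = diffusers.BitsAndBytesConfig(
    load_in_8bit=(quant_type == 'fp8'),
    load_in_4bit=(quant_type in ['nf4', 'fp4']),
    bnb_4bit_quant_type=quant_type,
    bnb_4bit_quant_storage='uint8',
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# quantize the transformer at load time; requires bitsandbytes to be installed
transformer = diffusers.FluxTransformer2DModel.from_pretrained(
    repo_id,
    subfolder='transformer',
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)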
6 changes: 3 additions & 3 deletions modules/model_quant.py
@@ -11,7 +11,7 @@ def load_bnb(msg='', silent=False):
if bnb is not None:
return bnb
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: module=bitsandbytes fn={fn}') # pylint: disable=protected-access
log.debug(f'Quantization: type=bitsandbytes fn={fn}') # pylint: disable=protected-access
install('bitsandbytes', quiet=True)
try:
import bitsandbytes
@@ -30,10 +30,10 @@ def load_quanto(msg='', silent=False):
if quanto is not None:
return quanto
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: module=quanto fn={fn}') # pylint: disable=protected-access
log.debug(f'Quantization: type=quanto fn={fn}') # pylint: disable=protected-access
install('optimum-quanto', quiet=True)
try:
from optimum import quanto as optimum_quanto
from optimum import quanto as optimum_quanto # pylint: disable=no-name-in-module
quanto = optimum_quanto
return quanto
except Exception as e:
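
Both helpers above follow the same install-on-demand pattern: return the cached module if it was already imported, otherwise install the package and import it. The hunk is truncated here, so the error-handling tail below is an assumption; install and log are the repository's existing helpers, the rest is a minimal sketch of the pattern rather than the file's exact code.

bnb = None

def load_bnb(msg='', silent=False):
    global bnb
    if bnb is not None:                      # already imported once, reuse the cached module
        return bnb
    install('bitsandbytes', quiet=True)      # ensure the package is present before importing
    try:
        import bitsandbytes
        bnb = bitsandbytes
        return bnb
    except Exception as e:
        if not silent:
            log.error(f'{msg} failed to import bitsandbytes: {e}')
        return None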
33 changes: 28 additions & 5 deletions modules/model_sd3.py
@@ -3,8 +3,11 @@
import transformers


default_repo_id = 'stabilityai/stable-diffusion-3-medium'


def load_sd3(checkpoint_info, cache_dir=None, config=None):
from modules import devices, modelloader, sd_models
from modules import shared, devices, modelloader, sd_models
repo_id = sd_models.path_to_repo(checkpoint_info.name)
dtype = devices.dtype
kwargs = {}
@@ -14,24 +17,24 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
if fn_size < 5e9:
kwargs = {
'text_encoder': transformers.CLIPTextModelWithProjection.from_pretrained(
repo_id,
default_repo_id,
subfolder='text_encoder',
cache_dir=cache_dir,
torch_dtype=dtype,
),
'text_encoder_2': transformers.CLIPTextModelWithProjection.from_pretrained(
repo_id,
default_repo_id,
subfolder='text_encoder_2',
cache_dir=cache_dir,
torch_dtype=dtype,
),
'tokenizer': transformers.CLIPTokenizer.from_pretrained(
repo_id,
default_repo_id,
subfolder='tokenizer',
cache_dir=cache_dir,
),
'tokenizer_2': transformers.CLIPTokenizer.from_pretrained(
repo_id,
default_repo_id,
subfolder='tokenizer_2',
cache_dir=cache_dir,
),
@@ -47,6 +50,26 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
modelloader.hf_login()
loader = diffusers.StableDiffusion3Pipeline.from_pretrained
kwargs['variant'] = 'fp16'

if len(shared.opts.bnb_quantization) > 0:
from modules.model_quant import load_bnb
load_bnb('Load model: type=SD3')
bnb_config = diffusers.BitsAndBytesConfig(
load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
bnb_4bit_compute_dtype=devices.dtype
)
if 'Model' in shared.opts.bnb_quantization:
transformer = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
kwargs['transformer'] = transformer
if 'Text Encoder' in shared.opts.bnb_quantization:
te3 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
kwargs['text_encoder_3'] = te3

pipe = loader(
repo_id,
torch_dtype=dtype,
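
As in the FLUX path, the pre-quantized submodules are handed to the pipeline loader as components so they are not loaded a second time. A minimal sketch of that assembly; the repo id and dtype are illustrative, and the text encoder here uses transformers' own BitsAndBytesConfig (the commit itself passes the diffusers config object to both loaders).

import torch
import diffusers
import transformers

repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'  # illustrative

transformer = diffusers.SD3Transformer2DModel.from_pretrained(
    repo_id,
    subfolder='transformer',
    quantization_config=diffusers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16),
    torch_dtype=torch.float16,
)
text_encoder_3 = transformers.T5EncoderModel.from_pretrained(
    repo_id,
    subfolder='text_encoder_3',
    quantization_config=transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16),
    torch_dtype=torch.float16,
)

# components passed as kwargs take precedence, so from_pretrained skips loading them again
pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
    repo_id,
    transformer=transformer,
    text_encoder_3=text_encoder_3,
    torch_dtype=torch.float16,
)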
3 changes: 2 additions & 1 deletion modules/sd_models.py
@@ -869,7 +869,8 @@ def apply_balanced_offload_to_module(pipe):
module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True)
module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
except Exception as e:
shared.log.error(f'Balanced offload: module={module_name} {e}')
if 'bitsandbytes' not in str(e):
shared.log.error(f'Balanced offload: module={module_name} {e}')
devices.torch_gc(fast=True)

apply_balanced_offload_to_module(sd_model)
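
The change above only narrows the logging: the filter suggests bitsandbytes-quantized modules can raise when balanced offload tries to move them between devices, and those expected failures are no longer reported as errors. A minimal sketch of the pattern, with illustrative names:

try:
    module = module.to('cpu')             # offload step that a bnb-quantized module may refuse
except Exception as e:
    if 'bitsandbytes' not in str(e):      # only surface unexpected failures
        log.error(f'Balanced offload: module={module_name} {e}')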
31 changes: 18 additions & 13 deletions modules/shared.py
@@ -490,24 +490,11 @@ def get_default_modes():
"cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"),
"deep_cache_interval": OptionInfo(3, "DeepCache cache interval", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),

"quant_sep": OptionInfo("<h2>Model Quantization</h2>", "", gr.HTML, {"visible": native}),
"quant_shuffle_weights": OptionInfo(False, "Shuffle the weights between GPU and CPU when quantizing", gr.Checkbox, {"visible": native}),
"nncf_compress_weights": OptionInfo([], "NNCF int8 compression enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
"optimum_quanto_weights": OptionInfo([], "Optimum.quanto quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
"optimum_quanto_weights_type": OptionInfo("qint8", "Optimum.quanto quantization type", gr.Radio, {"choices": ['qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2', 'qint4', 'qint2'], "visible": native}),
"optimum_quanto_activations_type": OptionInfo("none", "Optimum.quanto quantization activations ", gr.Radio, {"choices": ['none', 'qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2'], "visible": native}),
"torchao_quantization": OptionInfo([], "TorchAO quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}),
"torchao_quantization_type": OptionInfo("int8", "TorchAO quantization type", gr.Radio, {"choices": ["int8+act", "int8", "int4", "fp8+act", "fp8", "fpx"], "visible": native}),

"ipex_sep": OptionInfo("<h2>IPEX</h2>", "", gr.HTML, {"visible": devices.backend == "ipex"}),
"ipex_optimize": OptionInfo([], "IPEX Optimize for Intel GPUs", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"], "visible": devices.backend == "ipex"}),

"openvino_sep": OptionInfo("<h2>OpenVINO</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
"openvino_devices": OptionInfo([], "OpenVINO devices to use", gr.CheckboxGroup, {"choices": get_openvino_device_list() if cmd_opts.use_openvino else [], "visible": cmd_opts.use_openvino}), # pylint: disable=E0606
"nncf_quantize": OptionInfo([], "OpenVINO Quantize Models with NNCF", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": cmd_opts.use_openvino}),
"nncf_quant_mode": OptionInfo("INT8", "OpenVINO quantization mode for NNCF", gr.Radio, {"choices": ['INT8', 'FP8_E4M3', 'FP8_E5M2'], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_mode": OptionInfo("INT8", "OpenVINO compress mode for NNCF", gr.Radio, {"choices": ['INT8', 'INT8_SYM', 'INT4_ASYM', 'INT4_SYM', 'NF4'], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_raito": OptionInfo(1.0, "OpenVINO compress ratio for NNCF", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
"openvino_disable_model_caching": OptionInfo(False, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
"openvino_disable_memory_cleanup": OptionInfo(True, "OpenVINO disable memory cleanup after compile", gr.Checkbox, {"visible": cmd_opts.use_openvino}),

@@ -522,6 +509,24 @@ def get_default_modes():
"olive_cache_optimized": OptionInfo(True, 'Olive cache optimized models'),
}))

options_templates.update(options_section(('quantization', "Quantization Settings"), {
"bnb_quantization": OptionInfo([], "BnB quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}),
"bnb_quantization_type": OptionInfo("fp8", "BnB quantization type", gr.Radio, {"choices": ['nf4', 'fp8', 'fp4'], "visible": native}),
"bnb_quantization_storage": OptionInfo("uint8", "BnB quantization storage", gr.Radio, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"], "visible": native}),
"optimum_quanto_weights": OptionInfo([], "Optimum.quanto quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
"optimum_quanto_weights_type": OptionInfo("qint8", "Optimum.quanto quantization type", gr.Radio, {"choices": ['qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2', 'qint4', 'qint2'], "visible": native}),
"optimum_quanto_activations_type": OptionInfo("none", "Optimum.quanto quantization activations ", gr.Radio, {"choices": ['none', 'qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2'], "visible": native}),
"torchao_quantization": OptionInfo([], "TorchAO quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}),
"torchao_quantization_type": OptionInfo("int8", "TorchAO quantization type", gr.Radio, {"choices": ["int8+act", "int8", "int4", "fp8+act", "fp8", "fpx"], "visible": native}),
"nncf_compress_weights": OptionInfo([], "NNCF compression enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
"nncf_compress_weights_mode": OptionInfo("INT8", "NNCF compress mode", gr.Radio, {"choices": ['INT8', 'INT8_SYM', 'INT4_ASYM', 'INT4_SYM', 'NF4'] if cmd_opts.use_openvino else ['INT8']}),
"nncf_compress_weights_raito": OptionInfo(1.0, "NNCF compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
"nncf_quantize": OptionInfo([], "NNCF OpenVINO quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": cmd_opts.use_openvino}),
"nncf_quant_mode": OptionInfo("INT8", "NNCF OpenVINO quantization mode", gr.Radio, {"choices": ['INT8', 'FP8_E4M3', 'FP8_E5M2'], "visible": cmd_opts.use_openvino}),

"quant_shuffle_weights": OptionInfo(False, "Shuffle the weights between GPU and CPU when quantizing", gr.Checkbox, {"visible": native}),
}))

options_templates.update(options_section(('advanced', "Inference Settings"), {
"token_merging_sep": OptionInfo("<h2>Token merging</h2>", "", gr.HTML),
"token_merging_method": OptionInfo("None", "Token merging method", gr.Radio, {"choices": ['None', 'ToMe', 'ToDo']}),
