diff --git a/pyproject.toml b/pyproject.toml
index be15845..9e199b0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ exclude = []
 
 [project]
 name = "vspropainter"
-version = "1.2.0"
+version = "1.2.1"
 description = "ProPainter function for VapourSynth"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/vspropainter/__init__.py b/vspropainter/__init__.py
index 3170956..eb3e5ce 100644
--- a/vspropainter/__init__.py
+++ b/vspropainter/__init__.py
@@ -24,7 +24,7 @@
 from vspropainter.propainter_render import ModelProPainterIn, ModelProPainterOut
 from vspropainter.propainter_utils import *
 
-__version__ = "1.2.0"
+__version__ = "1.2.1"
 
 os.environ["CUDA_MODULE_LOADING"] = "LAZY"
 
diff --git a/vspropainter/model/misc.py b/vspropainter/model/misc.py
index df9974c..12a77b2 100644
--- a/vspropainter/model/misc.py
+++ b/vspropainter/model/misc.py
@@ -8,13 +8,17 @@
 import numpy as np
 from os import path as osp
 
+
 def constant_init(module, val, bias=0):
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.constant_(module.weight, val)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)
 
+
 initialized_logger = {}
+
+
 def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
     """Get the root logger.
     The logger will be initialized if it has not been initialized. By default a
@@ -45,7 +49,7 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
         logger.setLevel(log_level)
         # add file handler
         # file_handler = logging.FileHandler(log_file, 'w')
-        file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
+        file_handler = logging.FileHandler(log_file, 'a')  # Shangchen: keep the previous log
         file_handler.setFormatter(logging.Formatter(format_str))
         file_handler.setLevel(log_level)
         logger.addHandler(file_handler)
@@ -53,16 +57,24 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
     return logger
 
 
-#IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
-IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^(\d+)\.(\d+)\.(\d+)([\w\d\.].*)?$",\
-    torch.__version__)[0][:3])] >= [1, 12, 0]
+def IS_HIGH_VERSION() -> bool:
+    # check torch version for mps support
+    # see: https://huggingface.co/docs/diffusers/optimization/mps
+    try:
+        t_ver = [int(m) for m in
+                 list(re.findall(r"^(\d+)\.(\d+)\.(\d+)([\w\d\.].*)?$", torch.__version__)[0][:3])] >= [1, 13, 0]
+    except Exception:
+        return False
+    return t_ver
+
 
 def gpu_is_available():
-    if IS_HIGH_VERSION:
+    if IS_HIGH_VERSION():
         if torch.backends.mps.is_available():
             return True
     return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
+
 
 def get_device(gpu_id=None):
     if gpu_id is None:
         gpu_str = ''
@@ -71,10 +83,11 @@ def get_device(gpu_id=None):
     else:
         raise TypeError('Input should be int value.')
 
-    if IS_HIGH_VERSION:
+    if IS_HIGH_VERSION():
         if torch.backends.mps.is_available():
-            return torch.device('mps'+gpu_str)
-    return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+            return torch.device('mps' + gpu_str)
+    return torch.device(
+        'cuda' + gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
 
 
 def set_random_seed(seed):
@@ -129,4 +142,4 @@ def _scandir(dir_path, suffix, recursive):
             else:
                 continue
 
-    return _scandir(dir_path, suffix=suffix, recursive=recursive)
\ No newline at end of file
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
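
Review note: a minimal smoke test for the reworked helpers, assuming the package is importable in the current environment. The snippet is illustrative only and is not part of the change:

# Hypothetical sanity check for vspropainter/model/misc.py after this patch.
import torch
from vspropainter.model.misc import IS_HIGH_VERSION, gpu_is_available, get_device

print(torch.__version__)    # the version string parsed by the regex
print(IS_HIGH_VERSION())    # True on torch >= 1.13.0; False otherwise or if parsing fails
print(gpu_is_available())   # True when MPS, or CUDA with cuDNN, is usable
print(get_device())         # torch.device('mps'), 'cuda', or 'cpu'
print(get_device(0))        # device with an explicit index, e.g. 'cuda:0'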