Improved torch version detection in misc.py
dan64 committed Oct 29, 2024
1 parent c993544 commit 9c44ddd
Showing 3 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ exclude = []

[project]
name = "vspropainter"
version = "1.2.0"
version = "1.2.1"
description = "ProPainter function for VapourSynth"
readme = "README.md"
requires-python = ">=3.10"
2 changes: 1 addition & 1 deletion vspropainter/__init__.py
@@ -24,7 +24,7 @@
from vspropainter.propainter_render import ModelProPainterIn, ModelProPainterOut
from vspropainter.propainter_utils import *

__version__ = "1.2.0"
__version__ = "1.2.1"

os.environ["CUDA_MODULE_LOADING"] = "LAZY"

31 changes: 22 additions & 9 deletions vspropainter/model/misc.py
@@ -8,13 +8,17 @@
import numpy as np
from os import path as osp


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


initialized_logger = {}


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.
    The logger will be initialized if it has not been initialized. By default a
@@ -45,24 +49,32 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    logger.setLevel(log_level)
    # add file handler
    # file_handler = logging.FileHandler(log_file, 'w')
-    file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
+    file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
    file_handler.setFormatter(logging.Formatter(format_str))
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger


-#IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
-IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^(\d+)\.(\d+)\.(\d+)([\w\d\.].*)?$",\
-    torch.__version__)[0][:3])] >= [1, 12, 0]
+def IS_HIGH_VERSION() -> bool:
+    # check torch version for mps support.
+    # see: https://huggingface.co/docs/diffusers/optimization/mps
+    try:
+        t_ver = [int(m) for m in
+                 list(re.findall(r"^(\d+)\.(\d+)\.(\d+)([\w\d\.].*)?$", torch.__version__)[0][:3])] >= [1, 13, 0]
+    except Exception:
+        return False
+    return t_ver


def gpu_is_available():
-    if IS_HIGH_VERSION:
+    if IS_HIGH_VERSION():
        if torch.backends.mps.is_available():
            return True
    return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
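Editor's note: a minimal standalone sketch of the version check added above, so it can be exercised without installing a particular torch build. The helper name is_high_version and the sample strings are hypothetical; only the regex and the [1, 13, 0] threshold come from the diff. One quirk: '+' is not matched by [\w\d.], so local-version builds such as 2.1.0+cu118 fail the regex and return False through the except branch (harmless for the MPS check, since those are CUDA wheels).

import re

def is_high_version(version: str) -> bool:
    # Parse the leading "major.minor.patch" and compare against 1.13.0,
    # mirroring the IS_HIGH_VERSION() logic in the diff above.
    try:
        parts = re.findall(r"^(\d+)\.(\d+)\.(\d+)([\w\d\.].*)?$", version)[0][:3]
        return [int(m) for m in parts] >= [1, 13, 0]
    except Exception:
        # no match -> findall returns [] -> IndexError; treat as "too old"
        return False

print(is_high_version("2.1.0"))        # True
print(is_high_version("1.12.1"))       # False: below 1.13.0
print(is_high_version("1.13.0a0"))     # True: suffix matched by the optional group
print(is_high_version("2.1.0+cu118"))  # False: '+' is outside [\w\d.], so no match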


def get_device(gpu_id=None):
    if gpu_id is None:
        gpu_str = ''
@@ -71,10 +83,11 @@ def get_device(gpu_id=None):
    else:
        raise TypeError('Input should be int value.')

-    if IS_HIGH_VERSION:
+    if IS_HIGH_VERSION():
        if torch.backends.mps.is_available():
-            return torch.device('mps'+gpu_str)
-    return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+            return torch.device('mps' + gpu_str)
+    return torch.device(
+        'cuda' + gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')


def set_random_seed(seed):
@@ -129,4 +142,4 @@ def _scandir(dir_path, suffix, recursive):
                else:
                    continue

-    return _scandir(dir_path, suffix=suffix, recursive=recursive)
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
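Editor's note: a hypothetical usage sketch (not part of the commit) showing how the updated helpers fit together, assuming the module is importable as vspropainter.model.misc per the file listing above.

import torch
from vspropainter.model.misc import get_device, gpu_is_available

# get_device() resolves to 'mps' on Apple Silicon with torch >= 1.13,
# otherwise 'cuda' when CUDA and cuDNN are usable, otherwise 'cpu';
# an int argument pins a specific CUDA device index.
device = get_device()
print(gpu_is_available(), device)

# move a small model and an input tensor to the selected device
model = torch.nn.Linear(8, 8).to(device)
output = model(torch.randn(1, 8, device=device))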
