
Commit

Optimize pathology inference - support multi-gpu (#712)
* Optimize pathology inference - support multi-gpu

Signed-off-by: Sachidanand Alle <sachidanand.alle@gmail.com>

SachidanandAlle authored Mar 31, 2022
1 parent f9e37fe commit 3f4fef4
Showing 19 changed files with 200 additions and 302 deletions.
2 changes: 1 addition & 1 deletion monailabel/endpoints/wsi_infer.py
@@ -91,7 +91,7 @@ def run_wsi_inference(
request["image"] = session.image
request["session"] = session.to_json()

logger.info(f"WSI Infer Request: {request}")
logger.debug(f"WSI Infer Request: {request}")

result = instance.infer_wsi(request)
if result is None:
41 changes: 30 additions & 11 deletions monailabel/interfaces/app.py
@@ -13,6 +13,7 @@
import logging
import os
import platform
+ import random
import shutil
import tempfile
import time
@@ -592,7 +593,8 @@ def infer_wsi(self, request, datastore=None):
f"WSI/Inference Task is not Initialized. There is no model '{model}' available",
)

image = request["image"]
img_id = request["image"]
image = img_id
request_c = copy.deepcopy(task.config())
request_c.update(request)
request = request_c
@@ -609,38 +611,55 @@
image = datastore.get_image_uri(request["image"])

start = time.time()
logger.info(f"WSI Infer Request (final): {request}")
infer_tasks = create_infer_wsi_tasks(request, image)
+ if len(infer_tasks) > 1:
+     logger.info(f"WSI Infer Request (final): {request}")

logger.debug(f"Total WSI Tasks: {len(infer_tasks)}")
request["logging"] = request.get("logging", "WARNING" if len(infer_tasks) > 1 else "INFO")

multi_gpu = request.get("multi_gpu", False)
multi_gpu = request.get("multi_gpu", True)
multi_gpus = request.get("gpus", "all")
gpus = (
list(range(torch.cuda.device_count())) if not multi_gpus or multi_gpus == "all" else multi_gpus.split(",")
)
device_ids = [f"cuda:{id}" for id in gpus] if multi_gpu else [request.get("device", "cuda")]
logger.info(f"MultiGpu: {multi_gpu}; Using Device(s): {device_ids}")

res_json = {"annotations": [None] * len(infer_tasks)}
for idx, t in enumerate(infer_tasks):
t["logging"] = request["logging"]
t["device"] = device_ids[idx % len(device_ids)]
t["device"] = (
device_ids[idx % len(device_ids)]
if len(infer_tasks) > 1
else device_ids[random.randint(0, len(device_ids) - 1)]
)

- if len(infer_tasks) > 1 and len(device_ids) > 1:
-     with ThreadPoolExecutor(max_workers=len(device_ids), thread_name_prefix="WSI Infer") as executor:
+ total = len(infer_tasks)
+ max_workers = request.get("max_workers", len(device_ids))
+
+ if len(infer_tasks) > 1 and (max_workers == 0 or max_workers > 1):
+     logger.info(f"MultiGpu: {multi_gpu}; Using Device(s): {device_ids}; Max Workers: {max_workers}")
+     futures = {}
+     with ThreadPoolExecutor(max_workers if max_workers else None, "WSI Infer") as executor:
for t in infer_tasks:
tid = t["id"]
future = executor.submit(self._run_infer_wsi_task, t)
futures[t["id"]] = t, executor.submit(self._run_infer_wsi_task, t)

+ for tid, (t, future) in futures.items():
res = future.result()
res_json["annotations"][tid] = res
logger.info(f"{tid} => {len(res_json)} / {len(infer_tasks)}; Latencies: {res.get('latencies')}")
finished = len([a for a in res_json["annotations"] if a])
logger.info(
f"{img_id} => {tid} => {t['device']} => {finished} / {total}; Latencies: {res.get('latencies')}"
)
else:
for t in infer_tasks:
tid = t["id"]
res = self._run_infer_wsi_task(t)
res_json["annotations"][tid] = res
logger.info(f"{tid} => {len(res_json)} / {len(infer_tasks)}; Latencies: {res.get('latencies')}")
finished = len([a for a in res_json["annotations"] if a])
logger.info(
f"{img_id} => {tid} => {t['device']} => {finished} / {total}; Latencies: {res.get('latencies')}"
)

latency_total = time.time() - start
logger.debug("WSI Infer Time Taken: {:.4f}".format(latency_total))
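The gist of the change above: the removed executor block waited on each future right after submitting it, which serializes the tiles, while the new code submits every tile task first and only then collects results, keeping all devices busy. Below is a minimal standalone sketch of that submit-all-then-collect pattern with round-robin device assignment; run_task and the task payloads are hypothetical stand-ins for self._run_infer_wsi_task and create_infer_wsi_tasks, not the commit's actual code.

from concurrent.futures import ThreadPoolExecutor

def run_task(task):
    # Placeholder for per-tile inference on task["device"].
    return {"id": task["id"], "device": task["device"]}

device_ids = ["cuda:0", "cuda:1"]  # assume two GPUs are visible
tasks = [{"id": i} for i in range(8)]
for idx, t in enumerate(tasks):
    t["device"] = device_ids[idx % len(device_ids)]  # round-robin assignment

results = [None] * len(tasks)
with ThreadPoolExecutor(max_workers=len(device_ids)) as executor:
    # Submit everything first so all devices stay busy...
    futures = {t["id"]: executor.submit(run_task, t) for t in tasks}
    # ...then collect; calling result() inside the submit loop
    # (as the removed lines did) would run the tasks one at a time.
    for tid, future in futures.items():
        results[tid] = future.result()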
48 changes: 26 additions & 22 deletions monailabel/interfaces/tasks/infer.py
@@ -22,6 +22,7 @@
from monailabel.interfaces.exception import MONAILabelError, MONAILabelException
from monailabel.interfaces.utils.transform import run_transforms
from monailabel.transform.writer import Writer
+ from monailabel.utils.others.generic import device_list

logger = logging.getLogger(__name__)

@@ -57,7 +58,7 @@ class InferTask:

def __init__(
self,
- path: Union[str, Sequence[str]],
+ path: Union[None, str, Sequence[str]],
network: Union[None, Any],
type: Union[str, InferType],
labels: Union[str, None, Sequence[str], Dict[Any, Any]],
@@ -70,6 +71,7 @@ def __init__(
config: Union[None, Dict[str, Any]] = None,
load_strict: bool = False,
roi_size=None,
+ preload=False,
):
"""
:param path: Model File Path. Supports multiple paths to support versions (Last item will be picked as latest)
@@ -84,8 +86,9 @@ def __init__(
:param config: K,V pairs to be part of user config
:param load_strict: Load model in strict mode
:param roi_size: ROI size for scanning window inference
+ :param preload: Preload model/network on all available GPU devices
"""
- self.path = path
+ self.path = [] if not path else [path] if isinstance(path, str) else path
self.network = network
self.type = type
self.labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
@@ -99,8 +102,9 @@ def __init__(
self.roi_size = roi_size

self._networks: Dict = {}

self._config: Dict[str, Any] = {
# "device": "cuda",
# "device": device_list(),
# "result_extension": None,
# "result_dtype": None,
# "result_compress": False
@@ -111,6 +115,11 @@ def __init__(
if config:
self._config.update(config)

+ if preload:
+     for device in device_list():
+         logger.info(f"Preload Network for device: {device}")
+         self._get_network(device)

def info(self) -> Dict[str, Any]:
return {
"type": self.type,
@@ -127,19 +136,19 @@ def is_valid(self) -> bool:
if self.network or self.type == InferType.SCRIBBLES:
return True

- paths = [self.path] if isinstance(self.path, str) else self.path
+ paths = self.path
for path in reversed(paths):
- if os.path.exists(path):
+ if path and os.path.exists(path):
return True
return False

def get_path(self):
if not self.path:
return None

- paths = [self.path] if isinstance(self.path, str) else self.path
+ paths = self.path
for path in reversed(paths):
- if os.path.exists(path):
+ if path and os.path.exists(path):
return path
return None

@@ -247,7 +256,10 @@ def __call__(self, request) -> Tuple[str, Dict[str, Any]]:

# device
device = req.get("device", "cuda")
if device.startswith("cuda") and not torch.cuda.is_available():
device = "cpu"
req["device"] = device

logger.setLevel(req.get("logging", "INFO").upper())
logger.info(f"Infer Request (final): {req}")

@@ -346,9 +358,6 @@ def _get_network(self, device):
f"Model Path ({self.path}) does not exist/valid",
)

if device.startswith("cuda") and not torch.cuda.is_available():
device = "cpu"

cached = self._networks.get(device)
statbuf = os.stat(path) if path else None
network = None
Expand All @@ -360,16 +369,15 @@ def _get_network(self, device):

if network is None:
if self.network:
- network = self.network
+ network = copy.deepcopy(self.network)
+ network.to(torch.device(device))

if path:
checkpoint = torch.load(path, map_location=torch.device(device))
model_state_dict = checkpoint.get(self.model_state_dict, checkpoint)
network.load_state_dict(model_state_dict, strict=self.load_strict)
else:
- network = torch.jit.load(path, map_location=torch.device(device))
-
- if device.startswith("cuda"):
-     network = network.cuda(device)
+ network = torch.jit.load(path, map_location=torch.device(device)).to(torch.device(device))

network.eval()
self._networks[device] = (network, statbuf.st_mtime if statbuf else 0)
@@ -388,22 +396,18 @@ def run_inferer(self, data, convert_to_batch=True, device="cuda"):
"""

inferer = self.inferer(data)
logger.info("Inferer:: {} => {}".format(inferer.__class__.__name__, inferer.__dict__))

device = device if device else "cuda"
if device.startswith("cuda") and not torch.cuda.is_available():
device = "cpu"
logger.info("Inferer:: {} => {} => {}".format(device, inferer.__class__.__name__, inferer.__dict__))

network = self._get_network(device)
if network:
inputs = data[self.input_key]
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
inputs = inputs[None] if convert_to_batch else inputs
if device.startswith("cuda"):
inputs = inputs.cuda(torch.device(device))
inputs = inputs.to(torch.device(device))

with torch.no_grad():
outputs = inferer(inputs, network)

if device.startswith("cuda"):
torch.cuda.empty_cache()

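Two patterns in the infer.py changes are worth noting: _get_network deep-copies the base network per device (a single nn.Module cannot live on several GPUs at once, since .to() moves it in place) and caches the copy keyed by device, invalidating the entry when the model file's mtime changes; preload=True simply warms that cache for every device in device_list(). A rough sketch of the caching idea follows; NetworkCache and the simplified checkpoint handling are illustrative assumptions, not the actual class.

import copy
import os

import torch

class NetworkCache:
    # Sketch of the per-device cache used by _get_network (simplified).
    def __init__(self, base_network, path):
        self.base_network = base_network
        self.path = path
        self._networks = {}  # device -> (network, file mtime)

    def get(self, device):
        statbuf = os.stat(self.path) if self.path else None
        cached = self._networks.get(device)
        # Reuse the cached copy unless the model file changed on disk.
        if cached and (not statbuf or statbuf.st_mtime == cached[1]):
            return cached[0]

        # Each device gets its own deep copy of the network.
        network = copy.deepcopy(self.base_network).to(torch.device(device))
        if self.path:
            checkpoint = torch.load(self.path, map_location=torch.device(device))
            network.load_state_dict(checkpoint.get("model", checkpoint))
        network.eval()
        self._networks[device] = (network, statbuf.st_mtime if statbuf else 0)
        return network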
2 changes: 1 addition & 1 deletion monailabel/interfaces/utils/wsi.py
@@ -50,7 +50,7 @@ def create_infer_wsi_tasks(request, image):
rows = ceil(h / tile_size[1]) # ROW

if rows * cols > 1:
logger.info(f"Total Tiles to infer {rows} x {cols}: {rows * cols}")
logger.info(f"Total Tiles to infer {rows} x {cols}: {rows * cols}; Dimensions: {w} x {h}")

infer_tasks = []
count = 0
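For reference, the tile grid logged above comes from dividing the slide extent by the tile size and rounding up; a quick sanity check with illustrative numbers (the cols line is assumed to mirror the rows line shown in the hunk):

from math import ceil

w, h = 50000, 30000  # hypothetical WSI dimensions
tile_size = (2048, 2048)
cols = ceil(w / tile_size[0])  # 25 (COL)
rows = ceil(h / tile_size[1])  # 15 (ROW)
print(f"Total Tiles to infer {rows} x {cols}: {rows * cols}; Dimensions: {w} x {h}")
# Total Tiles to infer 15 x 25: 375; Dimensions: 50000 x 30000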
2 changes: 1 addition & 1 deletion monailabel/main.py
@@ -58,7 +58,7 @@ def args_start_server(self, parser):
parser.add_argument("--ssl_certfile", default=None, type=str, help="SSL certificate file")
parser.add_argument("--ssl_keyfile_password", default=None, type=str, help="SSL key file password")
parser.add_argument("--ssl_ca_certs", default=None, type=str, help="CA certificates file")
parser.add_argument("--workers", default=1, type=int, help="Number of worker processes")
parser.add_argument("--workers", default=None, type=int, help="Number of worker processes")
parser.add_argument("--limit_concurrency", default=None, type=int, help="Max concurrent connections")
parser.add_argument("--access_log", action="store_true", help="Enable access log")

4 changes: 2 additions & 2 deletions monailabel/tasks/scoring/epistemic.py
@@ -8,7 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+ import copy
import logging
import os
import time
@@ -115,7 +115,7 @@ def _load_model(self, path, network):
logger.info(f"Using {model_file} for running Epistemic")
model_ts = int(os.stat(model_file).st_mtime) if model_file and os.path.exists(model_file) else 1
if network:
- model = network
+ model = copy.deepcopy(network)
if model_file:
if torch.cuda.is_available():
checkpoint = torch.load(model_file)
4 changes: 2 additions & 2 deletions monailabel/tasks/scoring/tta.py
@@ -8,7 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+ import copy
import logging
import os
import time
@@ -126,7 +126,7 @@ def _load_model(self, path, network):
logger.info(f"Using {model_file} for running TTA")
model_ts = int(os.stat(model_file).st_mtime) if model_file and os.path.exists(model_file) else 1
if network:
- model = network
+ model = copy.deepcopy(network)
if model_file:
checkpoint = torch.load(model_file)
model_state_dict = checkpoint.get("model", checkpoint)
4 changes: 2 additions & 2 deletions monailabel/tasks/train/basic_train.py
@@ -51,7 +51,7 @@
from monailabel.interfaces.datastore import Datastore
from monailabel.interfaces.tasks.train import TrainTask
from monailabel.tasks.train.handler import PublishStatsAndModel, prepare_stats
- from monailabel.utils.others.generic import remove_file
+ from monailabel.utils.others.generic import device_list, remove_file

logger = logging.getLogger(__name__)

@@ -136,7 +136,7 @@ def __init__(
self._config = {
"name": "train_01",
"pretrained": True,
"device": "cuda",
"device": device_list(),
"max_epochs": 50,
"early_stop_patience": -1,
"val_split": 0.2,
10 changes: 9 additions & 1 deletion monailabel/utils/others/generic.py
@@ -19,7 +19,7 @@
import subprocess
import time

- import torch.cuda
+ import torch
from monai.apps import download_url

logger = logging.getLogger(__name__)
@@ -173,3 +173,11 @@ def download_file(url, path, delay=1, skip_on_exists=True):
download_url(url, path)
if delay > 0:
time.sleep(delay)


+ def device_list():
+     devices = [] if torch.cuda.is_available() else ["cpu"]
+     for i in range(torch.cuda.device_count()):
+         devices.append(f"cuda:{i}")
+
+     return devices
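The new helper enumerates every visible CUDA device and falls back to CPU, so callers can treat the result uniformly; expected behavior under the stated hardware assumptions:

from monailabel.utils.others.generic import device_list

print(device_list())  # with two GPUs visible: ['cuda:0', 'cuda:1']
print(device_list())  # on a CPU-only machine: ['cpu']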