Skip to content

speed_comparison

Takumi Ando edited this page Dec 22, 2020 · 4 revisions

Speed Comparison

As you probably know, many deep learning systems for histopathology load slide data as small patches. Thinking of your AWS bill, you might say: "Keep the WSIs as they are and don't crop patches — cropping takes a lot of storage and money, and it's a hassle!" But the GPU instance is what eats most of your money, so you should speed up every training iteration by any means you can. Time is money!

Result

  • Fastest to slowest (loading through a PyTorch DataLoader): patches pre-extracted with wsiprocess > OpenSlide > pyvips

Hmm... should I switch wsiprocess over to OpenSlide?

Environment

  • torch==1.7
  • torchvision==0.8
  • openslide-python==1.1.2
  • pyvips==2.1.12
  • libvips==8.10.1

Simple Comparison

  • Compare how fast each method loads every patch as a torch.float32 tensor with values in [0, 1] (one patch at a time, without a DataLoader).
# Import some packages
from itertools import product
from glob import glob
from tqdm import tqdm
from openslide.deepzoom import DeepZoomGenerator
import openslide
import pyvips
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as NF
import torchvision.transforms.functional as TF
from torchvision import transforms
from torchvision.io import read_image
# Shared benchmark configuration used by every speed_* function below.
wsi = "CMU-1.ndpi"  # slide file opened by the OpenSlide / pyvips benchmarks
tile_size = 512  # square patch edge length in pixels
batch_size = 16  # patches per DataLoader batch

With OpenSlide

def speed_openslide():
    """Time reading every deepest-level DeepZoom tile of ``wsi`` via OpenSlide.

    Tiles come from a ``DeepZoomGenerator`` with ``tile_size`` tiles and no
    overlap. Each tile is converted to a float32 CHW tensor in [0, 1] with
    ``TF.to_tensor``. The tensors are discarded, because only throughput
    matters here.
    """
    with openslide.OpenSlide(wsi) as slide:  # close the handle when done
        dz = DeepZoomGenerator(slide, tile_size=tile_size, overlap=0)
        # Derive the deepest level instead of hard-coding it: a fixed level
        # number is only correct for one particular slide size.
        level = dz.level_count - 1
        # level_tiles stores (columns, rows); get_tile expects (col, row).
        cols, rows = dz.level_tiles[level]
        tiles = product(range(cols), range(rows))
        for col, row in tqdm(tiles, total=cols * rows):
            tile = dz.get_tile(level, (col, row))
            tensor = TF.to_tensor(tile)  # the conversion is part of the timed work
# Result
100%|███████████████████████████████████████| 7500/7500 [01:20<00:00, 93.52it/s]

With pyvips

def speed_pyvips():
    """Time cropping every ``tile_size`` grid tile of ``wsi`` with pyvips.

    Crops start every ``tile_size`` pixels. Crops at the right and bottom
    edges are clipped to whatever remains of the slide. Each crop is copied
    into a uint8 HWC NumPy array and converted with ``TF.to_tensor``. Only
    the first three bands are kept, so the output has the same channel count
    as the ``DatasetWithVips`` variant.
    """
    slide = pyvips.Image.new_from_file(wsi)
    x_coords = range(0, slide.width, tile_size)
    y_coords = range(0, slide.height, tile_size)
    # Derive the progress-bar total instead of hard-coding one slide's count.
    n_tiles = len(x_coords) * len(y_coords)
    for x, y in tqdm(product(x_coords, y_coords), total=n_tiles):
        # Use tile_size (not a literal 512) so the crop matches the grid stride.
        width = min(tile_size, slide.width - x)
        height = min(tile_size, slide.height - y)
        tile = slide.crop(x, y, width, height)
        array = np.ndarray(
            buffer=tile.write_to_memory(),
            dtype=np.uint8,
            shape=[tile.height, tile.width, tile.bands],
        )
        # Drop any extra (e.g. alpha) band so every tensor is 3-channel.
        tensor = TF.to_tensor(array)[:3, :, :]
# Result
100%|██████████████████████████████████████| 7500/7500 [00:50<00:00, 147.10it/s]

Load from files

  • Unsurprisingly, this is by far the fastest: the patches are already cut out and only need to be decoded.
# Extract patches beforehand 
wsiprocess none CMU-1.ndpi -pw 512 -ph 512 -ow 0 -oh 0 -mm 0-255
def speed_from_patch():
    """Time decoding every pre-extracted foreground JPEG patch into a tensor.

    The patches are the wsiprocess output under ``CMU-1/patches/foreground``.
    Each one is decoded with ``read_image`` and scaled from uint8 to float32
    by dividing by 255.
    """
    patch_files = glob("CMU-1/patches/foreground/*.jpg")
    for patch_file in tqdm(patch_files):
        raw = read_image(patch_file)
        tensor = raw.type(torch.float32) / 255
# Result
100%|██████████████████████████████████████| 7500/7500 [00:12<00:00, 591.65it/s]

Comparison with PyTorch Dataloader

With OpenSlide

class DatasetWithOpenSlide(Dataset):
    """Map-style dataset over every deepest-level DeepZoom tile of a slide.

    Index ``idx`` maps to tile ``(idx % tiles_x, idx // tiles_x)``, so x
    varies fastest. Each item is a float32 CHW tensor from ``TF.to_tensor``.
    Tiles whose height or width differ from ``tile_size`` are resized with
    ``NF.interpolate`` (default nearest mode) to ``tile_size x tile_size``,
    so that items can be batched.

    Args:
        wsi: Path to the slide, opened with ``openslide.OpenSlide``.
        tile_size: Edge length in pixels of the DeepZoom tiles.
    """

    def __init__(self, wsi, tile_size):
        self.wsi = wsi  # kept so the handle can be reopened after unpickling
        self.tile_size = tile_size
        self._open()
        self.tiles_x, self.tiles_y = self.dz.level_tiles[-1]
        self.deepest_layer = self.dz.level_count - 1

    def _open(self):
        """(Re)create the OpenSlide handle and its DeepZoom view."""
        self.slide = openslide.OpenSlide(self.wsi)
        self.dz = DeepZoomGenerator(
            self.slide,
            tile_size=self.tile_size,
            overlap=0)

    def __getstate__(self):
        # The OpenSlide object wraps a native ctypes handle that pickle cannot
        # serialise. DataLoader workers started with "spawn" (e.g. Windows,
        # macOS on Python >= 3.8) receive a pickled dataset, so send only
        # plain attributes and reopen the slide in __setstate__.
        state = self.__dict__.copy()
        del state["slide"], state["dz"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._open()

    def __len__(self):
        return self.tiles_x * self.tiles_y

    def __getitem__(self, idx):
        x = idx % self.tiles_x
        y = idx // self.tiles_x
        img = self.dz.get_tile(self.deepest_layer, (x, y))
        tensor = TF.to_tensor(img)
        if tensor.size(1) != self.tile_size or tensor.size(2) != self.tile_size:
            # Explicit (H, W): "(self.tile_size)" was a bare int, not a tuple.
            tensor = NF.interpolate(
                tensor.unsqueeze(0), (self.tile_size, self.tile_size)
            ).squeeze(0)
        return tensor


def speed_DataLoaderWithOpenSlide():
    """Time one full pass of DatasetWithOpenSlide through a 4-worker DataLoader."""
    loader = DataLoader(
        DatasetWithOpenSlide(wsi, tile_size=tile_size),
        batch_size=batch_size,
        num_workers=4,
    )
    # Batches are discarded; only the loading throughput is measured.
    for _batch in tqdm(loader):
        pass
# Result
# 29.86 * 16(batch) = 477.76(patch/s)
100%|█████████████████████████████████████████| 469/469 [00:15<00:00, 29.86it/s]

With pyvips

class DatasetWithVips(Dataset):
    """Map-style dataset that crops a slide into a ``tile_size`` grid with pyvips.

    Crops start every ``tile_size`` pixels, with x as the outer loop of
    ``product``. Crops at the right and bottom edges are clipped to what
    remains of the slide. Any crop that is not ``tile_size x tile_size`` is
    then resized with ``NF.interpolate`` (default nearest mode) so that items
    can be batched. Only the first three channels are returned.

    Args:
        wsi: Path to the slide, opened with ``pyvips.Image.new_from_file``.
        tile_size: Grid stride and crop edge length in pixels.
        transform: Callable that turns the uint8 HWC NumPy crop into a CHW
            tensor (e.g. ``transforms.ToTensor()``). ``None`` falls back to
            ``TF.to_tensor``.
    """

    def __init__(self, wsi, tile_size, transform=None):
        self.wsi = wsi  # kept so the image can be reopened after unpickling
        self.tile_size = tile_size
        self.transform = transform
        self.slide = pyvips.Image.new_from_file(wsi)
        x_coords = range(0, self.slide.width, tile_size)
        y_coords = range(0, self.slide.height, tile_size)
        self.coords = list(product(x_coords, y_coords))

    def __getstate__(self):
        # A pyvips Image wraps a native libvips pointer and cannot be pickled.
        # DataLoader workers started with "spawn" receive a pickled dataset,
        # so send only plain attributes and reopen the image in __setstate__.
        state = self.__dict__.copy()
        del state["slide"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.slide = pyvips.Image.new_from_file(self.wsi)

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        x, y = self.coords[idx]
        # Use tile_size (not a literal 512) so the crop matches the grid stride.
        width = min(self.tile_size, self.slide.width - x)
        height = min(self.tile_size, self.slide.height - y)
        region = self.slide.crop(x, y, width, height)
        array = np.ndarray(
            buffer=region.write_to_memory(),
            dtype=np.uint8,
            shape=[region.height, region.width, region.bands],
        )
        # Apply the caller's transform, which was previously stored but ignored.
        to_tensor = TF.to_tensor if self.transform is None else self.transform
        # Keep the first three bands (drops an alpha band if one is present).
        tensor = to_tensor(array)[:3, :, :]
        if tensor.size(1) != self.tile_size or tensor.size(2) != self.tile_size:
            tensor = NF.interpolate(
                tensor.unsqueeze(0), (self.tile_size, self.tile_size)
            ).squeeze(0)
        return tensor


def speed_DataLoaderWithVips():
    """Time one full pass of DatasetWithVips through a 4-worker DataLoader."""
    to_tensor = transforms.Compose([transforms.ToTensor()])
    loader = DataLoader(
        DatasetWithVips(wsi, tile_size=tile_size, transform=to_tensor),
        batch_size=batch_size,
        num_workers=4,
    )
    # Batches are discarded; only the loading throughput is measured.
    for _batch in tqdm(loader):
        pass
# Result
# 20.91 * 16(batch) = 334.56(patch/s)
100%|█████████████████████████████████████████| 469/469 [00:22<00:00, 20.91it/s]

Load from files

class DatasetFromPatch(Dataset):
    """Map-style dataset over the ``*.jpg`` patches directly inside ``root``.

    Each item is a float32 CHW tensor scaled to [0, 1] by dividing by 255.
    Patches whose height or width differ from ``tile_size`` are resized with
    ``NF.interpolate`` (default nearest mode) to ``tile_size x tile_size``,
    so that items can be batched.

    Args:
        root: Directory holding the patch JPEGs. It is not searched recursively.
        tile_size: Target edge length in pixels.
    """

    def __init__(self, root, tile_size):
        # glob() returns matches in arbitrary, filesystem-dependent order; sort
        # so that the index -> patch mapping is reproducible.
        self.paths = sorted(glob(f"{root}/*.jpg"))
        self.tile_size = tile_size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Convert to float before resizing, so interpolate never has to handle
        # a uint8 input. For nearest mode the resulting values are identical.
        tensor = read_image(self.paths[idx]).type(torch.float32).div_(255)
        if tensor.size(1) != self.tile_size or tensor.size(2) != self.tile_size:
            tensor = NF.interpolate(
                tensor.unsqueeze(0), (self.tile_size, self.tile_size)
            ).squeeze(0)
        return tensor


def speed_DataLoaderFromPatch():
    """Time loading pre-extracted patches in batches and sanity-check each one.

    Every batch must be a 4-D float32 tensor whose values all lie in [0, 1].
    """
    loader = DataLoader(
        DatasetFromPatch("CMU-1/patches/foreground", tile_size=tile_size),
        batch_size=batch_size,
        num_workers=4,
    )
    for batch in tqdm(loader):
        assert batch.dtype == torch.float32
        assert batch.dim() == 4
        assert torch.all(batch >= 0)
        assert torch.all(batch <= 1)
# Result
# 53.24 * 16(batch) = 851.84(patch/s)
100%|█████████████████████████████████████████| 469/469 [00:08<00:00, 53.24it/s]