From e54981e7fcfd24263354d9c11fe70cb44457a594 Mon Sep 17 00:00:00 2001 From: Tom Runia Date: Fri, 14 Dec 2018 10:43:31 +0100 Subject: [PATCH] device check --- steerable/SCFpyr_PyTorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/steerable/SCFpyr_PyTorch.py b/steerable/SCFpyr_PyTorch.py index a4eb814..ec46e63 100644 --- a/steerable/SCFpyr_PyTorch.py +++ b/steerable/SCFpyr_PyTorch.py @@ -74,6 +74,8 @@ def build(self, im_batch): Returns: pyramid: list containing torch.Tensor objects storing the pyramid ''' + + assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(self.device, im_batch.device) assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32' assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]' assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'