diff --git a/src/icon_registration/network_wrappers.py b/src/icon_registration/network_wrappers.py index e2258f6..59430e6 100644 --- a/src/icon_registration/network_wrappers.py +++ b/src/icon_registration/network_wrappers.py @@ -42,11 +42,19 @@ def as_function(self, image): Returns a python function that maps a tensor of coordinates [batch x N_dimensions x ...] into a tensor of intensities. """ - - return lambda coordinates: compute_warped_image_multiNC( - image, coordinates, self.spacing, 1 - ) - + def image_as_function(coordinates): + if hasattr(coordinates, "isIdentity") and coordinate.shape == image.shape: + return image + + return compute_warped_image_multiNC( + image, coordinates, self.spacing, 1 + ) + return image_as_function + def tag_identity_map(self): + self.identity_map.isIdentity = True + for child in self.children(): + if isinstance(child, RegistrationModule): + child.tag_identity_map() def assign_identity_map(self, input_shape, parents_identity_map=None): self.input_shape = np.array(input_shape) self.input_shape[0] = 1