diff --git a/examples/text_to_image.py b/examples/text_to_image.py
index 3dec42e..2f4ad79 100644
--- a/examples/text_to_image.py
+++ b/examples/text_to_image.py
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
 from high_order_implicit_representation.networks import GenNet
 from pytorch_lightning.callbacks import LearningRateMonitor
-from high_order_implicit_representation.rendering import ImageGenerator
+from high_order_implicit_representation.rendering import Text2ImageGenerator
 from high_order_implicit_representation.single_image_dataset import (
     image_to_dataset,
     Text2ImageDataModule
 )
@@ -34,7 +34,7 @@ def run_implicit_images(cfg: DictConfig):
     data_module = Text2ImageDataModule(
         filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations
     )
-    image_generator = ImageGenerator(
+    image_generator = Text2ImageGenerator(
         filename=full_path[0], rotations=cfg.rotations, batch_size=cfg.batch_size
     )
     lr_monitor = LearningRateMonitor(logging_interval="epoch")
@@ -58,7 +58,7 @@ def run_implicit_images(cfg: DictConfig):
     checkpoint_path = f"{hydra.utils.get_original_cwd()}/{cfg.checkpoint}"
     logger.info(f"checkpoint_path {checkpoint_path}")

-    model = GenerativeNetwork.load_from_checkpoint(checkpoint_path)
+    model = GenNet.load_from_checkpoint(checkpoint_path)
     model.eval()

     image_dir = f"{hydra.utils.get_original_cwd()}/{cfg.images[0]}"
diff --git a/high_order_implicit_representation/rendering.py b/high_order_implicit_representation/rendering.py
index d94874b..cba27ed 100644
--- a/high_order_implicit_representation/rendering.py
+++ b/high_order_implicit_representation/rendering.py
@@ -11,12 +11,15 @@
 from high_order_implicit_representation.single_image_dataset import (
     image_neighborhood_dataset,
     image_to_dataset,
+    Text2ImageDataset,
 )
 import math
 import matplotlib.pyplot as plt
 import io
 import PIL
 from torchvision import transforms
+from torch.utils.data import DataLoader
+

 logger = logging.getLogger(__name__)
 default_size = [64, 64]
@@ -238,3 +241,69 @@ def on_train_epoch_end(
         trainer.logger.experiment.add_image(
             f"image", image, global_step=trainer.global_step
         )
+
+
+class Text2ImageGenerator(Callback):
+    """Log the current text-to-image fit next to the target image each epoch."""
+
+    def __init__(self, filename, rotations, batch_size):
+        self._dataset = Text2ImageDataset(filename, rotations=rotations)
+        self._dataloader = DataLoader(
+            self._dataset, batch_size=batch_size, shuffle=False
+        )
+        self._batch_size = batch_size
+        # NOTE: the original patch never initialized _image/_inputs, which
+        # on_train_epoch_end requires. Assumption: recover them from
+        # image_to_dataset as ImageGenerator does; adjust if its return
+        # signature differs.
+        self._image, self._inputs = image_to_dataset(filename, rotations=rotations)
+
+    @rank_zero_only
+    def on_train_epoch_end(
+        self, trainer: pl.Trainer, pl_module: pl.LightningModule
+    ) -> None:
+        pl_module.eval()
+        with torch.no_grad():
+            self._inputs = self._inputs.to(device=pl_module.device)
+
+            y_hat_list = []
+
+            # Evaluate the cached inputs one batch at a time. The original
+            # loop sliced self._inputs with the batch *contents*; slice by
+            # batch index instead.
+            for batch_index in range(len(self._dataloader)):
+                res = pl_module(
+                    self._inputs[
+                        batch_index * self._batch_size : (batch_index + 1) * self._batch_size
+                    ]
+                )
+                y_hat_list.append(res.detach().cpu())
+            y_hat = torch.cat(y_hat_list)
+
+            # Reshape to the target image dimensions and map [-1, 1] to [0, 1].
+            ans = y_hat.reshape(
+                self._image.shape[0], self._image.shape[1], self._image.shape[2]
+            )
+            ans = 0.5 * (ans + 1.0)
+
+            f, axarr = plt.subplots(1, 2)
+            axarr[0].imshow(ans.cpu().numpy())
+            axarr[0].set_title("fit")
+            axarr[1].imshow(self._image.cpu())
+            axarr[1].set_title("original")
+
+            for i in range(2):
+                axarr[i].axes.get_xaxis().set_visible(False)
+                axarr[i].axes.get_yaxis().set_visible(False)
+
+            buf = io.BytesIO()
+            plt.savefig(buf, dpi="figure", format="png", pad_inches=0.1)
+            buf.seek(0)
+            image = PIL.Image.open(buf)
+            image = transforms.ToTensor()(image)
+            plt.close(f)
+
+            trainer.logger.experiment.add_image(
+                "image", image, global_step=trainer.global_step
+            )
+        pl_module.train()
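
For context, a minimal sketch of how the new callback would be wired into a Lightning run, mirroring examples/text_to_image.py above. The image path, batch size, rotation count, and checkpoint path are hypothetical placeholders, not values from the patch:

    import pytorch_lightning as pl
    from high_order_implicit_representation.networks import GenNet
    from high_order_implicit_representation.rendering import Text2ImageGenerator
    from high_order_implicit_representation.single_image_dataset import Text2ImageDataModule

    full_path = ["images/example.jpg"]  # hypothetical training image

    # Train on the image while logging a fit/original comparison each epoch.
    data_module = Text2ImageDataModule(filenames=full_path, batch_size=256, rotations=2)
    image_generator = Text2ImageGenerator(
        filename=full_path[0], rotations=2, batch_size=256
    )

    # Hypothetical checkpoint; in the example script the model comes from the Hydra cfg.
    model = GenNet.load_from_checkpoint("outputs/last.ckpt")
    trainer = pl.Trainer(max_epochs=100, callbacks=[image_generator])
    trainer.fit(model, datamodule=data_module)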