
Using-GAN-fill-missing-part-of-Handwritten-Digits

Python NumPy Pandas Jupyter TensorFlow OpenCV Keras

How to build a neural network to fill the missing part of a handwritten digit using GANs

Step-by-step procedure

  1. Importing all of the dependencies
     We will be using the numpy, matplotlib, keras, tensorflow and tqdm packages in this exercise, with TensorFlow as the backend for Keras. In a notebook, use !pip install to install any missing dependencies. For the MNIST data, we will use the dataset that ships with the keras module, which only needs a simple import statement.
  2. Exploring the data
     We will load the MNIST data into our session from the keras module with mnist.load_data(). We will then print the shape and the size of the dataset, as well as the number of classes and the unique labels in the dataset.
  3. Masking/inserting noise
     For this project, we need to simulate a dataset of incomplete digits. So, let’s write a function that masks a small region of each original image to form the noised dataset. The idea is to mask an 8×8 region of the image, with the top-left corner of the mask falling between the 9th and 13th pixel (index 8 to 12) along both the x and y axes. This makes sure that we always mask around the center of the image.
  4. Reshaping
     Reshape the original dataset and the noised dataset to a shape of 60000×28×28×1. This is important because the 2D convolutions expect to receive images with a shape of 28×28×1.
  5. MNIST classifier
     To start off with modeling, let’s build a simple convolutional neural network (CNN) digit classifier with the following layers:
     • The first layer is a convolution layer with 32 filters of shape 3×3, ReLU activation and Dropout as the regularizer.
     • The second layer is a convolution layer with 64 filters of shape 3×3, ReLU activation and Dropout as the regularizer.
     • The third layer is a convolution layer with 128 filters of shape 3×3, ReLU activation and Dropout as the regularizer. Its output is then flattened.
     • The fourth layer is a Dense layer of 1024 neurons with ReLU activation.
     • The final layer is a Dense layer with 10 neurons, one for each of the 10 classes in the MNIST dataset, with softmax activation.
     The model is trained with batch_size set to 128 and the adam optimizer. validation_split is set to 0.2, which means that 20% of the training set is used as the validation set.
  6. Building the GAN model components
     With the idea that the final GAN model will be able to fill in the part of the image that is missing (masked), let’s define the generator.
     • Plotting the training – part 1
       During each epoch, a plotting function plots 9 generated images. For comparison, it also plots the corresponding 9 original target images and 9 noised input images. We need to apply the upscale function we’ve defined when plotting, so that the images are scaled to the range 0 to 255 and display correctly.
     • Plotting the training – part 2
       Let’s define another function that plots the images generated during each epoch. To show the difference, the plot also includes the original and the masked/noised images. Within each group of three rows, the top row contains the original images, the middle row the masked images, and the bottom row the generated images. The plot has 12 rows in the sequence row 1 – original, row 2 – masked, row 3 – generated, row 4 – original, row 5 – masked, …, row 12 – generated.
  7. Training loop
     Now we are at the most important part of the code, where all of the functions we previously defined are used. The steps are as follows:
     1. Load the generator by calling the img_generator() function.
     2. Load the discriminator by calling the img_discriminator() function. Compile it with the binary cross-entropy loss and optimizer_d as the optimizer, which we have defined under the hyperparameters section.
     3. Feed the generator and the discriminator to the dcgan() function. Compile the result with the binary cross-entropy loss and optimizer_g as the optimizer, which is also defined under the hyperparameters section.
     4. Create a new batch of original images and masked images.
     5. Generate new fake images by feeding the batch of masked images to the generator.
     6. Concatenate the original and generated images so that the first 128 images are all original and the next 128 images are all fake. It is important not to shuffle the data here; otherwise, it will be hard to train.
     7. Label the generated images as 0 and the original images as 0.9 instead of 1. This is one-sided label smoothing on the original images. It is called one-sided because we smooth the labels of the real images only. The reason for using label smoothing is to make the network resilient to adversarial examples.
     8. Set discriminator.trainable to True to enable training of the discriminator. Feed this set of 256 images and their corresponding labels to the discriminator for classification.
     9. Set discriminator.trainable to False and feed a new batch of 128 masked images, labeled as 1, to the GAN (DCGAN) for classification. Setting discriminator.trainable to False makes sure that the discriminator is not trained while the generator is being trained.
     10. Repeat steps 4 through 9 for the desired number of epochs.
     We have placed the plot_generated_images_combined() function and the generated_images_plot() function so that both produce a plot after the first iteration of the first epoch and after the end of each epoch. Feel free to place these plot functions according to how often you want the plots displayed.
  8. Predictions
     CNN classifier predictions on the noised and generated images: we will call the generator on the masked MNIST test data to generate images, that is, to fill in the missing part of the digits. The MNIST CNN classifier is 87.47% accurate on the generated data.
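The data-exploration step can be sketched as the summary below. This is a minimal sketch. The helper name describe_mnist is hypothetical and does not come from the project's notebook. It takes plain NumPy arrays, so it can be checked without downloading the dataset.

```python
import numpy as np

def describe_mnist(x_train, y_train, x_test, y_test):
    """Print and return the shapes, sizes and label set of an MNIST-style split."""
    labels = np.unique(np.concatenate([y_train, y_test]))  # sorted unique labels
    info = {
        "train_shape": x_train.shape,
        "test_shape": x_test.shape,
        "train_size": len(x_train),
        "test_size": len(x_test),
        "num_classes": int(labels.size),
        "unique_labels": labels.tolist(),
    }
    for key, value in info.items():
        print(f"{key}: {value}")
    return info
```

To run it on the real data, load the splits with `(x_train, y_train), (x_test, y_test) = mnist.load_data()`, using `from tensorflow.keras.datasets import mnist`. Then call `describe_mnist(x_train, y_train, x_test, y_test)`. It should report 60,000 training and 10,000 test images of 28×28 pixels, and 10 classes labeled 0 to 9.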
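The masking step can be sketched in NumPy as follows. This minimal version rests on a few assumptions:
- The function names are hypothetical.
- The top-left corner is drawn uniformly from indices 8 to 12 inclusive.
- The value written into the 8×8 square (fill_value) is not stated in this README. The default of 0, which erases the stroke pixels, is only one possible choice.

```python
import numpy as np

MASK_SIZE = 8    # side of the square mask, from the README
CORNER_MIN = 8   # smallest top-left index (9th pixel)
CORNER_MAX = 12  # largest top-left index (13th pixel)

def mask_image(image, rng, fill_value=0):
    """Return a copy of one 28x28 image with an 8x8 square overwritten.

    fill_value is an assumption: 0 erases the white-on-black strokes.
    """
    masked = np.array(image, copy=True)
    row = rng.integers(CORNER_MIN, CORNER_MAX + 1)  # upper bound is exclusive
    col = rng.integers(CORNER_MIN, CORNER_MAX + 1)
    masked[row:row + MASK_SIZE, col:col + MASK_SIZE] = fill_value
    return masked

def mask_dataset(images, seed=0, fill_value=0):
    """Mask every image in an (N, 28, 28) array, each at its own random position."""
    rng = np.random.default_rng(seed)
    return np.stack([mask_image(img, rng, fill_value) for img in images])
```

The corner never goes past index 12, so the masked square always lies within rows and columns 8 to 19 of the 28×28 image. This is what keeps the mask near the center.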
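For the reshaping step, here is a sketch of the reshape together with a pair of scaling helpers. The README mentions an upscale function for plotting but does not show it. The [-1, 1] range used below assumes a generator with a tanh output. That range, the downscale name, and the exact formulas are all assumptions.

```python
import numpy as np

def to_conv_input(images):
    """Reshape (N, 28, 28) images to (N, 28, 28, 1), the layout Conv2D expects."""
    return np.asarray(images).reshape(-1, 28, 28, 1)

def downscale(images):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1] (assumed model input range)."""
    return (np.asarray(images, dtype=np.float32) - 127.5) / 127.5

def upscale(images):
    """Map values in [-1, 1] back to uint8 pixels in [0, 255] for plotting."""
    pixels = np.asarray(images, dtype=np.float32) * 127.5 + 127.5
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
```

For MNIST, `to_conv_input(x_train).shape` is `(60000, 28, 28, 1)`. In this sketch, upscale maps model outputs back to 0 to 255 before plotting.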
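The classifier from the MNIST classifier step could look like this in Keras. The following come from the README:
- the layer list
- batch_size=128
- the adam optimizer
- validation_split=0.2

The following are assumptions:
- the strides and padding
- the dropout rate
- the sparse categorical cross-entropy loss (for integer labels)
- the epoch count

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

def build_mnist_classifier(dropout_rate=0.25):
    """CNN digit classifier following the README's layer list.

    Assumed: stride 2 with "same" padding on the first two convolutions,
    the dropout rate, and the loss function.
    """
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, (3, 3), strides=2, padding="same", activation="relu"),
        layers.Dropout(dropout_rate),
        layers.Conv2D(64, (3, 3), strides=2, padding="same", activation="relu"),
        layers.Dropout(dropout_rate),
        layers.Conv2D(128, (3, 3), padding="same", activation="relu"),
        layers.Dropout(dropout_rate),
        layers.Flatten(),
        layers.Dense(1024, activation="relu"),
        layers.Dense(10, activation="softmax"),  # one output per digit class
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",  # integer labels 0-9
                  metrics=["accuracy"])
    return model

def train_classifier(model, x_train, y_train, epochs=3):
    """Fit with batch_size=128 and validation_split=0.2; the epoch count is assumed."""
    return model.fit(np.asarray(x_train, dtype="float32"),
                     np.asarray(y_train),
                     batch_size=128,
                     epochs=epochs,
                     validation_split=0.2)
```

The stride-2 convolutions keep the model small. With stride 1, no padding and no pooling, the third convolution would output 22×22×128 = 61,952 values. The 1024-unit Dense layer alone would then need about 63 million weights.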
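The README names img_generator(), img_discriminator() and dcgan() but does not describe their layers. The architectures below are therefore hypothetical placeholders. Only the function names, the binary cross-entropy loss and the optimizer names optimizer_d and optimizer_g come from the text; the Adam settings given to those optimizers are assumed. In this sketch:
- The generator maps a masked 28×28×1 image to a completed image in [-1, 1].
- The discriminator outputs the probability that an image is real.
- dcgan() chains the two models.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical hyperparameters; the README defines optimizer_d and optimizer_g
# in a hyperparameters section that is not reproduced here.
optimizer_d = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
optimizer_g = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

def img_generator(input_shape=(28, 28, 1)):
    """Hypothetical encoder-decoder: masked image in, completed image out."""
    return keras.Sequential([
        keras.Input(shape=input_shape),
        layers.Conv2D(32, 3, strides=2, padding="same", activation="relu"),           # 14x14
        layers.Conv2D(64, 3, strides=2, padding="same", activation="relu"),           # 7x7
        layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu"),  # 14x14
        layers.Conv2DTranspose(1, 3, strides=2, padding="same", activation="tanh"),   # 28x28
    ], name="generator")

def img_discriminator(input_shape=(28, 28, 1)):
    """Hypothetical binary classifier: probability that an image is an original."""
    return keras.Sequential([
        keras.Input(shape=input_shape),
        layers.Conv2D(64, 3, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2D(128, 3, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Flatten(),
        layers.Dense(1, activation="sigmoid"),
    ], name="discriminator")

def dcgan(generator, discriminator, input_shape=(28, 28, 1)):
    """Chain generator -> discriminator so the generator is trained against D."""
    masked = keras.Input(shape=input_shape)
    return keras.Model(masked, discriminator(generator(masked)), name="dcgan")

generator = img_generator()
discriminator = img_discriminator()
discriminator.compile(loss="binary_crossentropy", optimizer=optimizer_d)

# The Keras guide recommends setting `trainable` before compile(); freeze D
# here so the combined model only updates the generator's weights.
discriminator.trainable = False
gan = dcgan(generator, discriminator)
gan.compile(loss="binary_crossentropy", optimizer=optimizer_g)
```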
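The 12-row layout from the part 2 plotting description can be built as a plain array before anything is drawn. The function names here are hypothetical, and the number of images per row is an assumption, since the README does not state it. The helper assumes that original, masked and generated hold the same number of images.

```python
import numpy as np

def interleave_rows(original, masked, generated, groups=4):
    """Arrange images as original / masked / generated rows, repeated `groups` times.

    Each input holds groups * per_row images of 28x28 (a trailing channel axis
    is fine). The result has shape (3 * groups, per_row, 28, 28).
    """
    per_row = len(original) // groups
    n = groups * per_row

    def to_groups(images):
        return np.asarray(images)[:n].reshape(groups, per_row, 28, 28)

    # Shape (groups, 3, per_row, 28, 28); merging the first two axes gives
    # row index = 3 * group + kind, i.e. original, masked, generated, original, ...
    stacked = np.stack([to_groups(original), to_groups(masked), to_groups(generated)], axis=1)
    return stacked.reshape(3 * groups, per_row, 28, 28)

def plot_comparison_rows(original, masked, generated, groups=4, path=None):
    """Plot the interleaved grid; expects pixel values already in [0, 255]."""
    import matplotlib.pyplot as plt  # imported here so interleave_rows works without matplotlib
    grid = interleave_rows(original, masked, generated, groups)
    n_rows, per_row = grid.shape[:2]
    fig, axes = plt.subplots(n_rows, per_row, figsize=(per_row, n_rows), squeeze=False)
    for r in range(n_rows):
        for c in range(per_row):
            axes[r, c].imshow(grid[r, c], cmap="gray", vmin=0, vmax=255)
            axes[r, c].axis("off")
    fig.tight_layout()
    if path is not None:
        fig.savefig(path)
    return fig
```

With groups=4, the grid has the 12 rows described in the README. With groups=1 and 9 images per input, the same helper gives three rows of 9 images (original, masked, generated). The README does not say whether the part 1 plot is arranged that way.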
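The training-loop steps can be sketched as below. The function takes the models as arguments. It only uses predict, train_on_batch and the trainable attribute, all of which Keras models provide. The following are assumptions:
- the helper names
- random sampling of batch indices
- a per-epoch callback in place of the README's plot placement

```python
import numpy as np

REAL_LABEL = 0.9  # one-sided label smoothing: real images get 0.9 instead of 1
FAKE_LABEL = 0.0  # generated images keep the hard label 0

def discriminator_batch(real, fake):
    """Concatenate real then fake images (not shuffled) with their labels."""
    x = np.concatenate([real, fake], axis=0)
    y = np.concatenate([np.full(len(real), REAL_LABEL), np.full(len(fake), FAKE_LABEL)])
    return x, y.astype("float32")

def train_gan(generator, discriminator, gan, x_original, x_masked,
              epochs=1, batch_size=128, seed=0, on_epoch_end=None):
    """Alternate one discriminator update and one generator update per batch."""
    rng = np.random.default_rng(seed)
    n = len(x_original)
    losses = []
    for epoch in range(epochs):
        for _ in range(n // batch_size):
            # New batch of matching original/masked images (sampling is assumed).
            idx = rng.integers(0, n, size=batch_size)
            fake = generator.predict(x_masked[idx], verbose=0)
            x_d, y_d = discriminator_batch(x_original[idx], fake)

            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(x_d, y_d)

            # Generator update: a new batch of masked images labeled as real (1),
            # passed through the combined model with the discriminator frozen.
            discriminator.trainable = False
            idx = rng.integers(0, n, size=batch_size)
            g_loss = gan.train_on_batch(x_masked[idx], np.ones(batch_size, dtype="float32"))
            losses.append((d_loss, g_loss))
        if on_epoch_end is not None:
            on_epoch_end(epoch)  # e.g. call the plotting functions here
    return losses
```

The first 128 discriminator inputs are always the originals and the last 128 are the fakes, which matches the instruction not to shuffle.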
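The predictions step can be sketched as a single comparison. The generator fills in the masked test digits, and the CNN classifier is then scored on both the masked and the completed images. The function name is hypothetical. Both models only need a Keras-style predict method.

```python
import numpy as np

def compare_accuracy(generator, classifier, x_masked, y_true, batch_size=128):
    """Classifier accuracy on masked test images and on the generator's completions."""
    y_true = np.asarray(y_true)

    def accuracy(images):
        probs = classifier.predict(images, batch_size=batch_size, verbose=0)
        return float(np.mean(np.argmax(probs, axis=1) == y_true))

    completed = generator.predict(x_masked, batch_size=batch_size, verbose=0)
    return {"masked": accuracy(x_masked), "generated": accuracy(completed)}
```

The "generated" entry is the quantity the README reports as 87.47%. The "masked" entry gives the accuracy on the noised images for comparison.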

Implementation of a GAN on the MNIST dataset by Arpita Halder, CSE Department, BBIT (MAKAUT), under Dr. Pratik Chattopadhyay, Assistant Professor, Dept. of Computer Science and Engineering, IIT (BHU), during a Summer Research Internship.