
Clarifications on Decoder and Embeddings #3

Open
aksg87 opened this issue Jan 19, 2022 · 4 comments


aksg87 commented Jan 19, 2022

@robogast

Happy to report I was able to train a VQ-VAE using a dataset. Very cool to see - and kudos for the nice Tensorboard outputs you have in place! 😎

  1. Do you have any suggestions or code for randomly sampling from the decoder in a generative fashion?

  2. Also, if you have a summary of these files and their purpose, that would be very helpful. I'd be happy to open a PR adding some comments to the repository.

Questions on:
calc_ssim_from_checkpoint.py # does this calculate SSIM across the dataset ❓
decode_embeddings.py # Specifications for db_path ❓
extract_embeddings.py # Does this save embedding to disk ❓

Ran successfully:
plot_from_checkpoint.py # plots a forward pass from a random sample ✅
train.py # trains a model ✅

Much appreciated!
-Akshay

@robogast
Member

Hi! Answers to your questions:

  1. You cannot sample the decoder directly; you need to train an autoregressive prior (e.g. PixelCNN, PixelSNAIL, ViT, ..., maybe using a discrete denoising model would be cool...) on the embeddings obtained by putting your dataset through the encoder.
    You then sample your autoregressive model for embeddings, and put those embeddings through the decoder. See the original VQ-VAE paper: https://arxiv.org/pdf/1711.00937.pdf
  2. On the scripts:
    • calc_ssim_from_checkpoint.py -> I simply hadn't added SSIM as a metric to TensorBoard yet when I wrote this script, so this file can be ignored (or removed) now.
    • decode_embeddings.py -> db_path points to the embeddings generated by your autoregressive model, so you don't have them yet.
    • extract_embeddings.py -> yes, in principle this script takes your model + dataset and creates the embeddings, which should be used as training input for your autoregressive model.
    • As a general note, these three files are scripts and not intended as library files, and should be treated as such (i.e. low quality control, a lot of hardcoded values).
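The two-stage pipeline above (train a prior on the extracted codes, sample the prior, decode the sampled codes) can be sketched roughly as follows. Everything here is a stand-in to show the shapes involved, not this repo's actual API: `sample_prior` draws uniform random indices where a real PixelCNN/PixelSNAIL would sample autoregressively, and `codebook`/`decoder` are toy tensors/modules with assumed sizes.

```python
# Minimal sketch of "sample prior -> look up codebook -> decode".
# All names and sizes are assumptions, not taken from the repo.
import torch

CODEBOOK_SIZE = 512    # number of entries in the VQ codebook (assumed)
EMBED_DIM = 64         # dimensionality of each codebook vector (assumed)
LATENT_SHAPE = (8, 8)  # spatial shape of the discrete latent grid (assumed)

def sample_prior(batch_size):
    # Stand-in for an autoregressive prior: a real prior samples one
    # position at a time, conditioned on the positions sampled so far.
    # Uniform random indices are used here purely for shape purposes.
    return torch.randint(0, CODEBOOK_SIZE, (batch_size, *LATENT_SHAPE))

def decode_indices(indices, codebook, decoder):
    # Look up the codebook vectors for the sampled indices, then run
    # them through the (frozen) VQ-VAE decoder.
    z_q = codebook[indices]        # (B, H, W, D)
    z_q = z_q.permute(0, 3, 1, 2)  # (B, D, H, W) for a conv decoder
    with torch.no_grad():
        return decoder(z_q)

# Toy stand-ins so the sketch runs end to end:
codebook = torch.randn(CODEBOOK_SIZE, EMBED_DIM)
decoder = torch.nn.ConvTranspose2d(EMBED_DIM, 1, kernel_size=4, stride=4)

idx = sample_prior(batch_size=2)
imgs = decode_indices(idx, codebook, decoder)
print(imgs.shape)  # torch.Size([2, 1, 32, 32])
```

In the real setup, `decoder` and `codebook` come from the trained VQ-VAE checkpoint and only the prior is newly trained.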

Nice to see that you're progressing :)


aksg87 commented Jan 21, 2022

@robogast - Appreciate all of the information! Need to review the paper again :)

I look forward to trying the other scripts and posting how things go!


aksg87 commented Jan 25, 2022

Hi @robogast

Your comments make much more sense now after reviewing the literature further :)

This is a nice overview from AI Epiphany!

https://www.youtube.com/watch?v=VZFVUrYcig0&t=1736s

@aksg87 aksg87 changed the title Sampling Randomly from Decoder? Clarifications on Decoder and Embeddings Jan 30, 2022

aksg87 commented Jan 30, 2022

Hi @robogast

I was trying to better understand encoding_idx. My understanding is that this is the last item from each of the 3 bottleneck layers? Curious why we throw the rest of the information away?

Thanks in advance!
-Akshay

import torch  # for torch.no_grad()

GPU = torch.device("cuda")  # assumed; defined elsewhere in the script

def extract_samples(model, dataloader):
    model.eval()
    model.to(GPU)

    with torch.no_grad():
        for sample, _ in dataloader:
            sample = sample.to(GPU)
            # encode() returns one tuple per bottleneck level;
            # zip(*...) transposes them, so the last group collects
            # the encoding indices from every level
            *_, encoding_idx = zip(*model.encode(sample))
            yield encoding_idx
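To see what the starred assignment does, here is a small illustration with plain strings standing in for the tensors. It assumes (my reading of the snippet, not confirmed against the repo) that model.encode() returns one tuple per bottleneck level, with that level's encoding indices as the last element:

```python
# Assumed layout: one tuple per bottleneck level, indices last.
levels = [
    ("z_e_0", "z_q_0", "idx_0"),  # level 0: pre-quant, quantized, indices
    ("z_e_1", "z_q_1", "idx_1"),  # level 1
    ("z_e_2", "z_q_2", "idx_2"),  # level 2
]

# zip(*levels) transposes the structure: it groups all first elements,
# then all second elements, then all third elements.
# "*_, encoding_idx" discards every group except the last one.
*_, encoding_idx = zip(*levels)
print(encoding_idx)  # ('idx_0', 'idx_1', 'idx_2')
```

So encoding_idx is not "the last layer" but the discrete indices from all three levels; the discarded groups are the continuous/quantized embeddings, which the autoregressive prior does not need since it is trained on the discrete codes only.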
