Training PixelCNN unclear #5

Open

Arksyd96 opened this issue Jun 3, 2024 · 1 comment

Arksyd96 commented Jun 3, 2024

Hi,

I'm using your implementation to generate MRIs. I have trained a VQ-VAE to reconstruct 3D MRIs, but I am unsure about which vectors to use for training the PixelCNN for sampling.

I attempted to understand your LMDB implementation, but it would take me a significant amount of time to fully grasp it. I'm not clear on what exactly is being stored in the LMDB database.

Given that the VQ-VAE encoder outputs multiple quantization vectors (one for each encoding block), what should be the specific input for the PixelCNN?

# vqvae is my trained 3D VQ-VAE; input is (batch, channels, H, W, D)
import torch

x = torch.randn(4, 3, 128, 128, 64).to('cuda')
decoded, (commitment_loss, quantizations, encoding_idx) = vqvae(x)

I think I'll have to modify the LMDB data module part.

Thank you!

robogast (Member) commented Jun 3, 2024

Hi! It has been a while since I've worked on this project, so my memory is not too sharp.
I'll try to see what I can do to help.

As far as I can see/remember, the input to the PixelCNN is a list of 3-dimensional one-hot-encoded tensors; see how I unpack them in the PixelCNN:

if len(batch) == 1 or not self.use_conditioning:  # no condition present
    data = batch[0]
    condition = None
    b, c, *dim = data.size()
else:
    data, condition = batch
    condition = condition.squeeze(dim=1)
    b, c, *dim = data.size()
    # upsample the one-hot condition to the spatial size of the data
    condition = F.interpolate(
        idx_to_one_hot(condition, num_classes=self.condition_dim),
        size=dim, mode='trilinear'
    )
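
To connect this to your VQ-VAE output, here is a minimal sketch (not from this repo; the index shape, the num_embeddings codebook size, and the variable names are my assumptions) of turning one level of discrete encoding indices into the channel-first one-hot tensor the PixelCNN consumes:

import torch
import torch.nn.functional as F

num_embeddings = 512                    # assumed codebook size
encoding_idx = torch.randint(           # stand-in for one level of VQ-VAE indices
    0, num_embeddings, (4, 32, 32, 16)  # (batch, H, W, D) latent grid
)

# F.one_hot puts the class dimension last: (B, H, W, D, num_embeddings)
one_hot = F.one_hot(encoding_idx, num_classes=num_embeddings)
# Move classes to the channel dim for 3D convolutions: (B, num_embeddings, H, W, D)
one_hot = one_hot.permute(0, 4, 1, 2, 3).float()
print(one_hot.shape)  # torch.Size([4, 512, 32, 32, 16])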

The whole pickling/txn context etc. is just boilerplate needed to make LMDB work.

The reason I used LMDB was that, at the time, it was the only database implementation that supported both memory-mapped arrays and concurrent reads (which is important for computational efficiency when running multi-node, which I did for sampling the full 512x512x128 volumes).
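
For context, here is a minimal sketch of the read pattern that machinery enables; the database path and key below are hypothetical, not the repo's actual schema:

import lmdb
import pickle

# readonly=True and lock=False allow many workers to read concurrently
env = lmdb.open('latents.lmdb', readonly=True, lock=False, readahead=False)
with env.begin(write=False) as txn:     # read-only transaction
    payload = txn.get(b'sample-00000')  # fetch one serialized entry
    sample = pickle.loads(payload)      # undo the pickling mentioned above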

As I said, I'm not entirely up to date on these kinds of workloads anymore, but two thoughts:

  • I'm guessing the PixelCNN part is not super relevant anymore; the problem with it is that it samples one element at a time, instead of using amortized sampling as in diffusion models. This means it inherently cannot scale to large volumes without a lot of computational power.
  • The Vector Quantization model might still be nice, but the VQ part especially should be replaced by a more well-developed implementation, such as the one from lucidrains (a minimal usage sketch follows this list): https://github.com/lucidrains/vector-quantize-pytorch
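
For reference, a minimal usage sketch of that library; the argument names follow its README, so check the version you install before relying on them:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim=256,               # feature dimension of the encoder output
    codebook_size=512,     # number of discrete codes
    decay=0.8,             # EMA decay for codebook updates
    commitment_weight=1.0  # weight of the commitment loss term
)

x = torch.randn(1, 1024, 256)            # (batch, sequence, dim)
quantized, indices, commit_loss = vq(x)  # quantized has the same shape as x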

If you have more questions let me know.

Robert Jan
