Official PyTorch implementation of "Learning Large-scale Neural Fields via Context Pruned Meta-Learning" (NeurIPS 2023) by Jihoon Tack, Subin Kim, Sihyun Yu, Jaeho Lee, Jinwoo Shin, Jonathan Richard Schwarz.
TL;DR: We propose an efficient meta-learning framework for scalable neural field learning that involves online data pruning of the context set.
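The idea of pruning the context set online can be illustrated with a small sketch. It is a simplified illustration of the idea, not the method implemented in this repository. It assumes that each context point (coordinate) is scored by the gradient norm of a squared-error loss with respect to the network output, and that only the top keep_ratio fraction of points is used for the inner-loop update. The helpers per_point_grad_norm, select_context, and pruned_mse and the parameter keep_ratio are hypothetical. Plain Python lists stand in for PyTorch tensors so the example is self-contained.

```python
import math
from typing import List, Sequence


def per_point_grad_norm(pred: Sequence[float], target: Sequence[float]) -> float:
    """L2 norm of the gradient of ||pred - target||^2 with respect to pred.

    The gradient is 2 * (pred - target), so its norm is 2 * ||pred - target||.
    """
    return 2.0 * math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))


def select_context(
    preds: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    keep_ratio: float,
) -> List[int]:
    """Return sorted indices of the context points kept for the inner-loop step.

    Hypothetical helper (assumed scoring rule): keeps the floor(keep_ratio * N)
    points, at least one, with the largest per-point gradient norm.
    Ties are broken in favor of the lower index.
    """
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    n = len(preds)
    k = max(1, int(keep_ratio * n))
    scores = [per_point_grad_norm(p, t) for p, t in zip(preds, targets)]
    ranked = sorted(range(n), key=lambda i: (-scores[i], i))
    return sorted(ranked[:k])


def pruned_mse(
    preds: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    keep: Sequence[int],
) -> float:
    """Mean squared error over the kept context points only."""
    total = sum(
        sum((p - t) ** 2 for p, t in zip(preds[i], targets[i])) for i in keep
    )
    return total / len(keep)


if __name__ == "__main__":
    preds = [[0.0], [0.5], [1.0], [0.2]]
    targets = [[0.0]] * 4
    keep = select_context(preds, targets, keep_ratio=0.5)
    print(keep)  # -> [1, 2]
    print(pruned_mse(preds, targets, keep))  # -> 0.625
```

In a real PyTorch training loop, the ranking step would typically be done on tensors (for example with torch.topk), and the pruned loss would drive the inner-loop gradient step instead of the loss over the full context set.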
conda create -n gradncp python=3.8 -y
conda activate gradncp
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install einops pyyaml tensorboardX tensorboard natsort pyspng av pytorch_msssim lpips
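After running the install commands above, one quick sanity check is to confirm that every package can be found by the interpreter in the gradncp environment. The following is a minimal, standard-library-only sketch. The mapping from pip names to import names is an assumption based on the packages' usual module names (for example, pyyaml is imported as yaml). The helper missing_packages is hypothetical and not part of the repository. It only checks that each package is present, not that the pinned versions (e.g., torch==1.12.1+cu113) are installed.

```python
import importlib.util
from typing import Dict, List

# pip package name -> top-level import name (assumed usual module names).
REQUIRED: Dict[str, str] = {
    "torch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "einops": "einops",
    "pyyaml": "yaml",
    "tensorboardX": "tensorboardX",
    "tensorboard": "tensorboard",
    "natsort": "natsort",
    "pyspng": "pyspng",
    "av": "av",
    "pytorch_msssim": "pytorch_msssim",
    "lpips": "lpips",
}


def missing_packages(required: Dict[str, str] = REQUIRED) -> List[str]:
    """Return the pip names whose top-level module cannot be found."""
    return [
        pip_name
        for pip_name, module in required.items()
        if importlib.util.find_spec(module) is None
    ]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```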
- Dataset path: /data. One can change the path in data.dataset.py (e.g., DATA_PATH = './PATH_TO_DATA').
- Download CelebA, CelebA-HQ, AFHQ, Imagenette-320, ImageNet, Text, UCF-101, Librispeech, ERA5.
# Learnit
CUDA_VISIBLE_DEVICES=0 python main.py --configs ./configs/main/maml_celeba.yaml
# Ours
CUDA_VISIBLE_DEVICES=0 python main.py --configs ./configs/main/ours_celeba.yaml
- Example of <PATH TO CHECKPOINT>: ./logs/maml_celeba/best.pth
# Learnit
CUDA_VISIBLE_DEVICES=0 python eval.py --configs ./configs/evaluation/eval_celeba.yaml --load_path ./logs/xxxx/best.model
# Ours (CelebA) example
CUDA_VISIBLE_DEVICES=0 python eval.py --configs ./configs/evaluation/eval_celeba_ours.yaml --load_path ./logs/xxxx/best.model
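The training and evaluation commands above share one pattern: a GPU index in CUDA_VISIBLE_DEVICES, a script (main.py or eval.py), a --configs YAML file, and, for evaluation, a --load_path checkpoint. When sweeping several configs, it can help to build these commands programmatically. The sketch below is hypothetical: build_command and launch are illustrative helpers, not part of the repository. It uses only the flags that appear in the commands above, and by default it only returns the equivalent shell line without running anything.

```python
import os
import shlex
import subprocess
from typing import Dict, List, Optional, Tuple


def build_command(
    script: str, config: str, gpu: str = "0", load_path: Optional[str] = None
) -> Tuple[Dict[str, str], List[str]]:
    """Return (env, argv) matching the README's train/eval command pattern."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    argv = ["python", script, "--configs", config]
    if load_path is not None:
        argv += ["--load_path", load_path]
    return env, argv


def launch(
    script: str,
    config: str,
    gpu: str = "0",
    load_path: Optional[str] = None,
    dry_run: bool = True,
) -> str:
    """Return the equivalent shell line; run it only when dry_run is False."""
    env, argv = build_command(script, config, gpu, load_path)
    if not dry_run:
        subprocess.run(argv, env=env, check=True)
    return f"CUDA_VISIBLE_DEVICES={shlex.quote(gpu)} " + shlex.join(argv)


if __name__ == "__main__":
    for cfg in ["./configs/main/maml_celeba.yaml", "./configs/main/ours_celeba.yaml"]:
        print(launch("main.py", cfg))  # dry run: prints the command only
```

shlex.join requires Python 3.8 or later, which matches the python=3.8 environment created above.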
This code is mainly built upon the JAX Learnit, JAX Functa, PyTorch Siren, PyTorch MetaSDF, PyTorch Meta-SparseINR, and PyTorch COIN++ repositories.
@inproceedings{tack2023learning,
title={Learning Large-scale Neural Fields via Context Pruned Meta-Learning},
author={Tack, Jihoon and Kim, Subin and Yu, Sihyun and Lee, Jaeho and Shin, Jinwoo and Schwarz, Jonathan Richard},
booktitle={Advances in Neural Information Processing Systems},
year={2023}
}