Improvements to parquet dataloading, sampling, batch sampling #742

Merged: 7 commits merged into graphnet-team:main on Sep 11, 2024

Conversation

@pweigel (Collaborator) commented Aug 22, 2024

This pull request implements changes that improve the speed of loading data with ParquetDataset using samplers, along with a few minor changes.

The first change is the ability to pass a Sampler and/or a BatchSampler to GraphNeTDataModule when the DataLoader is created. Generally, the sampler and batch sampler need access to the dataset before they can be used in the dataloader, so they are created in _create_dataloader.

The other large change is the implementation of RandomChunkSampler and LenMatchBatchSampler from the 2nd-place Kaggle competition solution. Essentially, RandomChunkSampler picks files randomly, and LenMatchBatchSampler groups events of similar length into batches. This is particularly useful for transformers, since less time is spent padding/truncating sequences.
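To illustrate the length-matching idea, here is a minimal, single-process sketch (names like bucket_width are made up for illustration; this is not the actual RandomChunkSampler/LenMatchBatchSampler code):

from collections import defaultdict
from typing import Iterator, List, Sequence

def len_match_batches(
    lengths: Sequence[int], batch_size: int, bucket_width: int = 16
) -> Iterator[List[int]]:
    """Yield batches of indices whose sequence lengths fall in the same bucket."""
    buckets = defaultdict(list)
    for idx, length in enumerate(lengths):
        bucket = length // bucket_width
        buckets[bucket].append(idx)
        if len(buckets[bucket]) == batch_size:
            yield buckets.pop(bucket)
    # Any leftover, incomplete buckets are yielded at the end.
    yield from buckets.values()

# Short events (lengths 5, 7, 9, 6) and long events (100, 102, 104, 101) land
# in separate batches, so each batch pads only up to its own maximum length.
batches = list(len_match_batches([5, 100, 7, 102, 9, 104, 6, 101], batch_size=4))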

Right now, the sampler and batch sampler settings live in the dataloader kwargs. For example:

train_dataloader_kwargs={
    "batch_size": 256,
    "num_workers": 4,
    "multiprocessing_context": "spawn",
    "sampler": RandomChunkSampler,
    "sampler_kwargs": {},
    "batch_sampler": LenMatchBatchSampler,
    "batch_sampler_kwargs": {
        "batch_size": config["batch_size"],
        "num_workers": 4,
        "drop_last": True,
        "multiprocessing_context": "spawn",
    },
},

This has some pros and cons. Creating these objects inside GraphNeTDataModule._create_dataloader means that the train, validation, and test dataloaders can each have their own samplers/batch samplers. However, it also means the extra keys must be removed from the kwargs before the dataloader is created.
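As a rough sketch of what that stripping could look like (function and variable names here are illustrative, and I assume the sampler wraps the dataset while the batch sampler wraps the sampler, as with torch's built-ins; the real logic lives in GraphNeTDataModule._create_dataloader):

from torch.utils.data import DataLoader, Dataset

def build_dataloader(dataset: Dataset, dataloader_kwargs: dict) -> DataLoader:
    kwargs = dict(dataloader_kwargs)  # copy so the caller's dict is untouched
    # Pop the extra keys; DataLoader would reject them as unknown arguments.
    sampler_cls = kwargs.pop("sampler", None)
    sampler_kwargs = kwargs.pop("sampler_kwargs", {})
    batch_sampler_cls = kwargs.pop("batch_sampler", None)
    batch_sampler_kwargs = kwargs.pop("batch_sampler_kwargs", {})

    sampler = sampler_cls(dataset, **sampler_kwargs) if sampler_cls else None
    if batch_sampler_cls is None:
        return DataLoader(dataset, sampler=sampler, **kwargs)
    batch_sampler = batch_sampler_cls(sampler, **batch_sampler_kwargs)
    # batch_size, shuffle, sampler, and drop_last are mutually exclusive
    # with batch_sampler in torch's DataLoader, so drop them here.
    kwargs.pop("batch_size", None)
    kwargs.pop("shuffle", None)
    kwargs.pop("drop_last", None)
    return DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)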

I should clarify that this implementation of LenMatchBatchSampler is a little different from the original. I noticed that the batch sampler was a CPU bottleneck during training with large batch sizes. I think this is because batch samplers run on the main process, so I implemented a method that uses multiprocessing to speed that up. This means you will have fewer cores loading files, but I found that forming batches was actually the main bottleneck for my use case. When num_workers > 0, the files are grouped into segments for each worker to process. For example, with chunks_per_segment = 8, each worker will length-match batches from 8 files and return any leftover/incomplete batches. When all of the complete batches are finished, the leftovers are combined to form batches. The num_workers = 0 case should be very close to the original implementation.
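A simplified, single-process sketch of that segment logic (the real LenMatchBatchSampler fans segments out to multiprocessing workers; events_in and match_fn are hypothetical stand-ins for the file-reading and length-matching steps):

from typing import Callable, Iterator, List, Sequence, Tuple

def batches_from_segments(
    files: Sequence[str],
    chunks_per_segment: int,
    events_in: Callable[[str], List[int]],
    match_fn: Callable[[List[int]], Tuple[List[List[int]], List[int]]],
    drop_last: bool = True,
) -> Iterator[List[int]]:
    """Length-match within segments of files, then pool leftovers into batches."""
    leftover: List[int] = []
    for i in range(0, len(files), chunks_per_segment):
        segment = files[i : i + chunks_per_segment]
        events = [e for f in segment for e in events_in(f)]
        # In the real sampler, each worker handles one segment in parallel.
        complete, rest = match_fn(events)
        yield from complete    # complete batches stream out immediately
        leftover.extend(rest)  # incomplete remainders are pooled
    # Once all complete batches are finished, combine leftovers into batches.
    complete, rest = match_fn(leftover)
    yield from complete
    if rest and not drop_last:
        yield rest  # final partial batch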

In terms of performance, I saw a 2-4x increase in training speed (batches/sec) when training large transformer models with large batch sizes, compared to the normal sampling methods.

A smaller change is renaming the "batch" labeled variables in ParquetDataset to "chunk" to clarify that they are not batches. In fact, I don't think any sort of batching should occur in the dataloaders.

@RasmusOrsoe (Collaborator) left a comment


Hey @pweigel, thank you for this very clean pull request 🚀!

I think including the option to pass samplers makes a lot of sense, and this would all be a great addition to the repo. I am curious, though: do you know how this methodology compares to just moving the length-based matching to a custom collate_fn?

There are quite a few lines that exceed the 79-character limit (see the CodeClimate report). Could you shorten those, please? The remaining CodeClimate flags can be ignored.

@RasmusOrsoe self-requested a review on September 5, 2024, 12:41
@RasmusOrsoe (Collaborator) left a comment

🚀

@pweigel (Collaborator, Author) commented Sep 6, 2024

> I think including the option to pass samplers makes a lot of sense, and this would all be a great addition to the repo. I am curious, though: do you know how this methodology compares to just moving the length-based matching to a custom collate_fn?

I can try this out, but I think the issue I ran into would happen with collate_fn too (if I understand correctly how it works), namely the length-matching for large batches on the main process. I was using quite large batches (128 to 512), so having multiple processes generate them overcame that bottleneck. I'm not sure that can be done with collate_fn, which is why it is part of the batch sampler implementation.
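For reference, a pad-to-longest collate_fn of the kind discussed might look like this (a sketch assuming each dataset item is a variable-length (seq_len, features) tensor; note that collate_fn only assembles a batch whose membership the sampler has already chosen):

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(items: List[torch.Tensor]) -> torch.Tensor:
    """Pad a batch of variable-length sequences up to the longest one."""
    return pad_sequence(items, batch_first=True)  # (batch, max_len, features)

# Hypothetical usage:
# DataLoader(dataset, batch_size=256, collate_fn=pad_collate)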

@pweigel merged commit e8140c5 into graphnet-team:main on Sep 11, 2024
13 of 14 checks passed