Improvements to parquet dataloading, sampling, batch sampling #742
Conversation
* Renamed the batch variables in `ParquetDataset` to chunk variables
* Implemented `RandomChunkSampler` and `LenMatchBatchSampler` w/ modifications
Hey @pweigel, thank you for this very clean pull request 🚀!

I think including the option to pass samplers makes a lot of sense, and that this would all be a great addition to the repo. I am curious though - do you know how this methodology compares to just moving length-based matching to a custom `collate_fn`?

There are quite a few lines that exceed the 79-character limit (see the CodeClimate report here). Could you shorten those, please? The remaining CodeClimate flags can be ignored.
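(For context, a minimal sketch of the `collate_fn` alternative being asked about, assuming events arrive as variable-length `[n_pulses, n_features]` tensors; this is illustrative, not code from the PR or the repo.)

```python
from torch.nn.utils.rnn import pad_sequence


def len_match_collate(macro_batch, micro_batch_size=256):
    # Illustrative only: receive an oversized "macro-batch" of
    # variable-length tensors, sort by length, and pad each
    # similar-length micro-batch separately so little padding is wasted.
    macro_batch = sorted(macro_batch, key=len)
    micro_batches = []
    for i in range(0, len(macro_batch), micro_batch_size):
        chunk = macro_batch[i : i + micro_batch_size]
        micro_batches.append(pad_sequence(chunk, batch_first=True))
    return micro_batches
```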
🚀
I can try this out, but I think the issue I ran into would happen with …
This pull request implements some changes to improve the speed of loading data with `ParquetDataset` using samplers, plus a few other minor changes.

The first change is the ability to pass a `Sampler` and/or a `BatchSampler` into `GraphNeTDataModule` when the `DataLoader` is created. Generally, the sampler and batch sampler need access to the dataset before being used in the dataloader, so they are created in `_create_dataloader`.

The other large change is the implementation of `RandomChunkSampler` and `LenMatchBatchSampler`, adapted from the 2nd-place Kaggle competition winners. Essentially, the `RandomChunkSampler` picks files randomly, and the `LenMatchBatchSampler` groups events of similar length into batches. This is particularly useful for transformers, since less time is spent padding/truncating sequences.

Right now, the sampler and batch sampler settings live in the dataloader kwargs. For example:
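(The example snippet did not survive here; below is a hedged reconstruction of the general shape. The key names, the `graphnet.data.samplers` import path, and the values are assumptions, not the PR's verbatim example.)

```python
from graphnet.data.samplers import (  # assumed module path
    LenMatchBatchSampler,
    RandomChunkSampler,
)

train_dataloader_kwargs = {
    "num_workers": 8,
    # Extra keys popped off in `_create_dataloader` before the
    # `DataLoader` is constructed (the samplers need the dataset first):
    "sampler": RandomChunkSampler,
    "batch_sampler": LenMatchBatchSampler,
    "batch_sampler_kwargs": {"batch_size": 1024, "chunks_per_segment": 8},
}
```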
This has some pros and cons. Creating these objects in `GraphNeTDataModule._create_dataloader` means that the train, test, and validation dataloaders can each have their own samplers/batch samplers. However, it also means that the extra keys in the kwargs must be removed before creating the dataloader.

I should clarify that the implementation of `LenMatchBatchSampler` is a little different from the original. I noticed that the batch sampler was a CPU bottleneck for training when using large batch sizes. I think this is because batch samplers run on the main process, so I implemented a method that uses multiprocessing to speed it up. This means you will have fewer cores loading files, but I found that forming batches was actually the main bottleneck for my use case. When `num_workers > 0`, the files are grouped into segments for each worker to process. For example, when `chunks_per_segment = 8`, each worker will length-match batches from 8 files and return any leftover/incomplete batches. When all of the complete batches are finished, the leftovers are combined to form batches. The `num_workers = 0` case should be very close to the original implementation.

In terms of performance, I saw a 2-4x increase in training speed (batches/sec) when training large transformer models with large batch sizes, compared to the normal sampling methods.
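(To make the length-matching idea concrete, here is a minimal single-process sketch; the class name, `bucket_width`, and the bucketing heuristic are illustrative assumptions, not the PR's multiprocessing implementation.)

```python
from collections import defaultdict
from typing import Iterator, List, Sequence

import torch
from torch.utils.data import Sampler


class ToyLenMatchBatchSampler(Sampler):
    """Single-process sketch of length-matched batching."""

    def __init__(self, lengths: Sequence[int], batch_size: int,
                 bucket_width: int = 16) -> None:
        self._lengths = lengths
        self._batch_size = batch_size
        self._bucket_width = bucket_width

    def __iter__(self) -> Iterator[List[int]]:
        buckets = defaultdict(list)
        for idx in torch.randperm(len(self._lengths)).tolist():
            # Coarse length key: events within `bucket_width` pulses of
            # each other share a bucket.
            key = self._lengths[idx] // self._bucket_width
            buckets[key].append(idx)
            if len(buckets[key]) == self._batch_size:
                yield buckets[key]
                buckets[key] = []
        # Combine leftover/incomplete buckets into final batches, as in
        # the "leftovers" step described above.
        leftovers = [i for bucket in buckets.values() for i in bucket]
        for i in range(0, len(leftovers), self._batch_size):
            yield leftovers[i : i + self._batch_size]
```

Passed as `batch_sampler` to a `DataLoader`, each yielded list of indices becomes one batch.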
A smaller change is renaming the "batch"-labeled variables in `ParquetDataset` to "chunk", to clarify that they are not batches. In fact, I don't think any sort of batching should occur in the dataloaders.