Generate osu! standard beatmap object coordinates using a diffusion model with a transformer backbone.
This project is something of a successor to Mapperator, as both can turn a partial beatmap that contains only rhythm and spacing into a fully playable beatmap. With this, the task of automatically generating a beatmap from an mp3 can be split into two parts: one part generates rhythm and spacing from the audio, and this part turns that into a fully playable beatmap.
The second purpose is style transfer between beatmaps. Because you can extract the required features from existing beatmaps, it's possible to have this AI remap any beatmap into another style.
The third purpose was to prove that deep learning AI is capable of modelling more complex geometric relations in beatmaps. Hit object relations come in such a large variety that it is almost impossible to model them algorithmically. In that sense, the problem poses challenges similar to image recognition, where deep learning approaches have been shown to be very effective.
You need PyTorch with CUDA in order to use the GPU to train models. Training will not work without a GPU. Sampling can be done without a GPU, but it is significantly faster with one.
Versions I used:
- Python 3.10
- PyTorch 2.1.0
- CUDA 11.8
Other dependencies:
- slider
- numpy
- matplotlib
- pandas
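Before training, it is worth checking that PyTorch can actually see your GPU. This is just a sanity check, not part of the repository:

```python
import torch

# Training needs a CUDA-capable GPU; sampling also runs on CPU, just much slower.
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```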
The easiest way to create your own beatmaps is to use our Colab notebook.
`sample.py` lets you generate new beatmaps using the rhythm and spacing from an existing beatmap. You can also provide a specific style to map in by providing the beatmap ID of a map in the training data.

- Get a trained checkpoint from here or from training your own model.
- Generate `beatmap_idx.pickle` for your dataset by running `generate_beatmap_idx.py` in the testing folder. You need to edit the path to your dataset in the script before running. If you downloaded the checkpoint from here, you can use the `beatmap_idx.pickle` that is already present in the repository.
- Run `sample.py`:

python sample.py --beatmap "path to beatmap" --ckpt "DiT-B-00-0700000.pt"
Important arguments:

- `--beatmap` The beatmap to take rhythm and spacing from.
- `--ckpt` The training checkpoint to use.
- `--model` The model corresponding to the checkpoint.
- `--num-classes` The number of beatmaps in the dataset used to train the model.
- `--beatmap_idx` Path to the `beatmap_idx.pickle` file specific to your dataset.
- `--num-sampling-steps` The number of diffusion steps. Should be between 1 and 1000.
- `--cfg-scale` Scalar for classifier-free guidance. Amplifies the effect of the style transfer. 1.0 for normal style.
- `--style-id` The beatmap ID of the beatmap in the training data to use the style from.
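For example, to remap a beatmap into the style of a particular map from the training data, the arguments can be combined like this. The style ID and the `--cfg-scale` value below are placeholders, not recommended settings:

python sample.py --beatmap "path to beatmap" --ckpt "DiT-B-00-0700000.pt" --style-id 1234567 --cfg-scale 2.0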
This project is still in development, so things are likely to change.
Use the `Mapperator.ConsoleApp` in Mapperator to generate a dataset from your osu! folder.
Just grab the latest release and run it using .NET 6.
Mapperator.ConsoleApp.exe dataset -m Standard -s Ranked -i 200000 -o "path to output folder"
This command generates a dataset with all ranked osu! standard gamemode beatmaps in your folder whose ID is at least 200k. There are several ways to filter the beatmaps to put in the dataset, so use the help command to figure out the arguments.
Alternatively you can download the dataset here.
You train the model using `train.py`. It has several arguments that control the model, data, and hyperparameters.
Important arguments:

- `--data-path` The path to your dataset.
- `--num-classes` The number of beatmaps in your dataset.
- `--data-start` The start index of the range of mapsets in the dataset.
- `--data-end` The end index of the range of mapsets in the dataset. Not inclusive.
- `--model` The model to train. There are 4 models with increasing sizes.
- `--global-batch-size` The combined batch size over all GPUs.
- `--num-workers` The number of parallel data loading processes per GPU.
- `--ckpt-every` The number of training steps between checkpoints.
- `--seq-len` The length of the subsequences of the beatmap used as training examples. Determines the context size.
- `--stride` The distance between windows during data loading. A bigger stride means smaller epochs.
- `--ckpt` Path to a checkpoint file to resume training from.
- `--dist` The distribution strategy to use. `gloo` works for Windows.
- `--lr` The learning rate.
- `--relearn-embeds` Forget the learnt embeddings in the `ckpt` and learn a new embedding table. Important if you train on a different dataset from what your checkpoint was trained on.
Start training with `torchrun`. `--nproc-per-node` determines the number of GPUs to use.
torchrun --nproc-per-node=1 train.py --data-path "..\all_ranked_sets_ever" --model DiT-B --num-workers 4 --epochs 100 --global-batch-size 256 --ckpt-every 20000 --seq-len 128 --dist gloo
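If a run gets interrupted, the same command can be resumed from a saved checkpoint with `--ckpt` (the path below is a placeholder). Add `--relearn-embeds` on top of that if the dataset differs from the one the checkpoint was trained on, as noted above:

torchrun --nproc-per-node=1 train.py --data-path "..\all_ranked_sets_ever" --model DiT-B --num-workers 4 --epochs 100 --global-batch-size 256 --ckpt-every 20000 --seq-len 128 --dist gloo --ckpt "path to checkpoint"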
The diffusion model takes a sequence of data points where each data point represents a single thing in a beatmap that has a coordinate. These could for example be a circle, slider head, spinner start, red anchor, catmull anchor, slider end with 2 repeats, etc. Together these data points describe all the hit objects.
A data point contains the following information:
- The time of the datapoint.
- The distance to the previous datapoint.
- The type of the datapoint.
There are the following types:
- is circle
- is circle NC
- is spinner
- is spinner end
- is sliderhead
- is sliderhead NC
- is bezier anchor
- is perfect anchor
- is catmull anchor
- is red anchor
- is last anchor
- is slider end 0 repeat
- is slider end 1 repeat
- is slider end 2 repeats
- is slider end even repeats
- is slider end uneven repeats
The output of the diffusion model is a sequence which gives the X and Y coordinates for each data point.
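To make this concrete, here is an illustrative sketch of the representation in Python. The type names, field layout, and example values are made up for the sake of the example and do not match the repository's code exactly:

```python
from enum import IntEnum
from typing import NamedTuple

class DataPointType(IntEnum):
    CIRCLE = 0
    CIRCLE_NC = 1
    SPINNER = 2
    SPINNER_END = 3
    SLIDER_HEAD = 4
    SLIDER_HEAD_NC = 5
    BEZIER_ANCHOR = 6
    PERFECT_ANCHOR = 7
    CATMULL_ANCHOR = 8
    RED_ANCHOR = 9
    LAST_ANCHOR = 10
    SLIDER_END_0_REPEATS = 11
    SLIDER_END_1_REPEAT = 12
    SLIDER_END_2_REPEATS = 13
    SLIDER_END_EVEN_REPEATS = 14
    SLIDER_END_UNEVEN_REPEATS = 15

class DataPoint(NamedTuple):
    time: float      # time of the data point in the beatmap (ms)
    distance: float  # distance to the previous data point (osu! pixels)
    type: DataPointType

# Input: a short slider (head, red anchor, last anchor, slider end) followed by a circle.
sequence = [
    DataPoint(1000.0,   0.0, DataPointType.SLIDER_HEAD_NC),
    DataPoint(1100.0,  60.0, DataPointType.RED_ANCHOR),
    DataPoint(1200.0,  60.0, DataPointType.LAST_ANCHOR),
    DataPoint(1200.0,   0.0, DataPointType.SLIDER_END_0_REPEATS),
    DataPoint(1400.0, 120.0, DataPointType.CIRCLE),
]

# Output of the diffusion model: one (x, y) playfield coordinate per data point.
```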
I copied the model architecture from Scalable Diffusion Models with Transformers and modified it for my purpose.
The main changes are:
- Remove the patchify layers, so the input and output are sequences instead of images.
- Remove the auto-encoder so this is not a latent diffusion model.
- Remove the positional embedding and embed the time of the data point instead of its position in the sequence (a sketch of this idea follows the list).
- Add inputs for the additional information of the data point.
- Add attention masking so sequence length can be extended during sampling without too much loss in quality.
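For the time embedding, the usual sinusoidal embedding can simply be fed the time of each data point instead of its index in the sequence. This is a rough sketch of the idea, not the repository's exact code; the embedding size and frequency range are arbitrary here:

```python
import torch

def embed_time(t: torch.Tensor, dim: int = 256, max_period: float = 10000.0) -> torch.Tensor:
    """Sinusoidal embedding of data point times, shape [N] -> [N, dim].

    Same scheme as a positional embedding, but applied to the time in the beatmap
    rather than the index in the sequence, so the model sees relative timing.
    """
    half = dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(max_period)) * torch.arange(half) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

times = torch.tensor([1000.0, 1100.0, 1200.0, 1200.0, 1400.0])  # ms, illustrative
emb = embed_time(times)  # -> shape [5, 256]
```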
Training beatmaps are converted to sequences of data points and then split into overlapping windows with a fixed number of data points. The time values in a window get a random offset, so the absolute position of the window is unknown while the relative timing stays intact. Data augmentation is also used to randomly flip the positions of data points horizontally or vertically.
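A rough sketch of that augmentation, assuming the standard 512x384 osu! playfield (illustrative code, not the repository's actual data loader):

```python
import random

PLAYFIELD_WIDTH, PLAYFIELD_HEIGHT = 512, 384  # osu! playfield size in osu! pixels

def augment_window(times, positions, max_offset=10000.0):
    """Randomly offset the times of a window and randomly flip its coordinates."""
    # One random offset for the whole window: the absolute position in the map is
    # hidden, while the relative timing between data points stays intact.
    offset = random.uniform(0.0, max_offset)
    times = [t + offset for t in times]

    # Flip horizontally and/or vertically with 50% probability each.
    flip_x = random.random() < 0.5
    flip_y = random.random() < 0.5
    positions = [
        (PLAYFIELD_WIDTH - x if flip_x else x,
         PLAYFIELD_HEIGHT - y if flip_y else y)
        for x, y in positions
    ]
    return times, positions
```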
The beatmap ID of the beatmap the window came from is provided as a class label. This causes the model to learn an embedding for each beatmap in the training data which is somehow descriptive of how objects are placed.
While training is done on windows of 128 data points, sampling an entire beatmap generally requires more than 128 data points. Luckily, self-attention allows us to freely change the sequence length after training. To help with this, we added an attention mask which limits attention to a small neighbourhood around each data point instead of the whole sequence.
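As an illustration, such a mask can be built as a band over sequence indices; the neighbourhood size here is arbitrary, and the repository may define the neighbourhood differently (for example by time rather than index):

```python
import torch

def local_attention_mask(seq_len: int, window: int = 16) -> torch.Tensor:
    """Boolean mask of shape [seq_len, seq_len], True where attention is allowed.

    Each data point may only attend to data points at most `window` positions away,
    so the effective context stays close to what the model saw during training,
    even when the sequence is much longer than 128 at sampling time.
    """
    idx = torch.arange(seq_len)
    return (idx[None, :] - idx[:, None]).abs() <= window

mask = local_attention_mask(512)  # sample a longer sequence than the training window
# e.g. pass as attn_mask to torch.nn.functional.scaled_dot_product_attention
```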