I fine-tuned GPT-2 on South Park episodes using the Hugging Face Transformers library. The training data comes from this Kaggle dataset.
The repo contains four modules:
- Data Preprocessing: Contains the code that preprocesses the dataset and creates the `SouthPark_Data_test.pkl` and `SouthPark_Data_train.pkl` files.
- Train: This module contains the code that trains the model.
- Testing: Computes the ROUGE-1, ROUGE-2, and ROUGE-L scores on the test set.
- Inference: This module is used to generate episodes.
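For readers unfamiliar with the metrics reported by the Testing module, here is a minimal, dependency-free sketch of how ROUGE-1, ROUGE-2, and ROUGE-L F-scores can be computed for one generated text against one reference. It is not the code from this repo: the function names, the whitespace tokenization, and the choice to report the F-measure are all assumptions made for illustration, and the Testing module may well rely on an existing ROUGE package instead.

```python
from collections import Counter


def _ngrams(tokens, n):
    """Count the n-grams of a token list (empty if the list is shorter than n)."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def _f1(overlap, cand_total, ref_total):
    """Harmonic mean of precision and recall; 0.0 when there is no overlap."""
    if overlap == 0 or cand_total == 0 or ref_total == 0:
        return 0.0
    precision = overlap / cand_total
    recall = overlap / ref_total
    return 2 * precision * recall / (precision + recall)


def rouge_n(candidate, reference, n):
    """ROUGE-N F-score based on clipped n-gram overlap (assumed whitespace tokens)."""
    cand = _ngrams(candidate.lower().split(), n)
    ref = _ngrams(reference.lower().split(), n)
    overlap = sum((cand & ref).values())  # Counter & keeps the minimum counts
    return _f1(overlap, sum(cand.values()), sum(ref.values()))


def _lcs_length(a, b):
    """Length of the longest common subsequence, computed row by row."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, 1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l(candidate, reference):
    """ROUGE-L F-score based on the longest common subsequence of tokens."""
    cand = candidate.lower().split()
    ref = reference.lower().split()
    return _f1(_lcs_length(cand, ref), len(cand), len(ref))


if __name__ == "__main__":
    generated = "the cat sat on the mat"
    target = "the cat is on the mat"
    print(rouge_n(generated, target, 1), rouge_n(generated, target, 2), rouge_l(generated, target))
```

Scores for a whole test set would typically be averaged over all generated/reference pairs.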
The models can be found here. Once downloaded, put the model folders you want to use in the /AI folder. When calling the `get_model` function, pass the name of the model folder in the `checkpoint` parameter.
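To illustrate the folder layout described above, the sketch below loads a downloaded checkpoint folder directly with the Transformers `Auto*` classes. It does not reproduce the repo's `get_model` function, whose signature is not shown here. The helper names `resolve_checkpoint` and `load_checkpoint` are hypothetical, the `AI` directory is assumed to be relative to the working directory, and the checkpoint folder is assumed to contain both the model weights and the tokenizer files.

```python
from pathlib import Path


def resolve_checkpoint(name, root="AI"):
    """Build the path to a model folder inside the AI directory (hypothetical helper)."""
    return Path(root) / name


def load_checkpoint(name, root="AI"):
    """Load tokenizer and model from a local checkpoint folder (illustrative sketch)."""
    path = resolve_checkpoint(name, root)
    if not path.is_dir():
        raise FileNotFoundError(f"No model folder found at {path}")

    # Imported here so the path check above works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumption: the folder holds tokenizer files as well as model weights.
    tokenizer = AutoTokenizer.from_pretrained(str(path))
    model = AutoModelForCausalLM.from_pretrained(str(path))
    return tokenizer, model
```

With a folder named, say, `my_checkpoint` placed in the AI directory (the name is a placeholder), `load_checkpoint("my_checkpoint")` would return the tokenizer and model for generation.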