A PyTorch implementation of InfoGAIL built on top of stable-baselines3 and imitation.
Core changes were made to the imitation repository (v0.2.0) to implement InfoGAIL. We have kept only the necessary files from the imitation repository.
Two new classes in `src\imitation\rewards\discrim_nets.py`:
- `WassersteinDiscrimNet`: inherits `DiscrimNet` and overrides `disc_loss` to implement the Wasserstein loss used to train the discriminator.
- `DiscrimNetWGAIL`: inherits `WassersteinDiscrimNet` and overrides `reward_train`, with -logits as the reward for the generator.
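The two overrides above can be sketched roughly as follows. This is a minimal illustration, not the repository's code: it assumes the convention that logits are higher for generator samples and that labels are 1 for generator samples and 0 for expert samples, and the function names `wasserstein_disc_loss` and `wgail_reward_train` are hypothetical stand-ins for the `disc_loss` and `reward_train` methods.

```python
import torch


def wasserstein_disc_loss(logits_gen_is_high: torch.Tensor,
                          labels_gen_is_one: torch.Tensor) -> torch.Tensor:
    """Wasserstein-style critic loss (assumed sign convention).

    Minimizing this pushes generator logits up and expert logits down.
    """
    is_gen = labels_gen_is_one.bool()
    gen_logits = logits_gen_is_high[is_gen]
    expert_logits = logits_gen_is_high[~is_gen]
    return expert_logits.mean() - gen_logits.mean()


def wgail_reward_train(logits_gen_is_high: torch.Tensor) -> torch.Tensor:
    """Generator reward: -logits, so expert-like samples earn more reward."""
    return -logits_gen_is_high


# Toy batch: the first two samples come from the generator, the last two from the expert.
logits = torch.tensor([2.0, 1.0, -1.0, -3.0])
labels = torch.tensor([1, 1, 0, 0])

loss = wasserstein_disc_loss(logits, labels)
rewards = wgail_reward_train(logits)
print(loss.item(), rewards.tolist())
```

Under this convention, samples the critic scores as expert-like (low logits) receive the highest reward.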
Two new classes in `src\imitation\algorithms\adversarial.py`:
- `WGAIL`: the core changes from the `GAIL` class are `DiscrimNetWGAIL` as the discriminator and RMSprop instead of Adam as `disc_opt_cls`.
- `WassersteinAdversarialTrainer`: inherits the `AdversarialTrainer` class and adds gradient clipping to the `train_disc` function.
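A discriminator update with RMSprop and gradient clipping might look like the sketch below. It is only an illustration under stated assumptions: the tiny critic network, the learning rate, the `max_grad_norm` threshold, the helper name `train_disc_step`, and the choice of `torch.nn.utils.clip_grad_norm_` as the clipping method are all hypothetical and not taken from the repository.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in critic; the real discriminator lives in discrim_nets.py.
critic = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))

# RMSprop instead of Adam as the discriminator optimizer (learning rate is illustrative).
disc_opt = torch.optim.RMSprop(critic.parameters(), lr=5e-5)


def train_disc_step(expert_batch: torch.Tensor,
                    gen_batch: torch.Tensor,
                    max_grad_norm: float = 1.0) -> float:
    """One critic update: Wasserstein loss, gradient clipping, optimizer step."""
    disc_opt.zero_grad()
    # Assumed convention: logits should be high for generator samples.
    loss = critic(expert_batch).mean() - critic(gen_batch).mean()
    loss.backward()
    # Rescale gradients so that their global norm does not exceed max_grad_norm.
    torch.nn.utils.clip_grad_norm_(critic.parameters(), max_grad_norm)
    disc_opt.step()
    return loss.item()


expert_batch = torch.randn(8, 4) * 10.0
gen_batch = torch.randn(8, 4) * 10.0
loss_value = train_disc_step(expert_batch, gen_batch, max_grad_norm=1.0)

# Global gradient norm after clipping (gradients are still stored after the step).
grad_norm = torch.sqrt(sum((p.grad ** 2).sum() for p in critic.parameters())).item()
print(loss_value, grad_norm)
```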
A sample test script for WGAIL: `python .\minigrid_wgail_training_script.py -r testing_wgail -t minigrid_empty_right_down -f --vis-trained`
The trained policy remained consistent when the test environment was changed from "MiniGrid-Empty-6x6-v0" to "MiniGrid-Empty-8x8-v0" and "MiniGrid-Empty-5x5-v0".
To avoid any more core changes to the imitation library, all classes needed to run a CNN version of GAIL and WGAIL are saved in the `cnn_modules` folder.
Two new discriminator classes in `cnn_modules/cnn_discriminator.py`:
- `ActObsCNN`: uses a NatureCNN backbone from stable-baselines3 to extract features from an image observation. The observation features are concatenated with the action, and the rest works as in `ActObsMLP`.
- `ObsOnlyCNN`: same as `ActObsCNN`, but no action is used.
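The idea behind `ActObsCNN` can be sketched as below. This is a simplified, self-contained illustration rather than the repository's class: a small convolutional stack stands in for the stable-baselines3 NatureCNN backbone, actions are assumed to be discrete and one-hot encoded, and the class name `ActObsCNNSketch`, layer sizes, and input shape are all hypothetical.

```python
import torch
from torch import nn


class ActObsCNNSketch(nn.Module):
    """CNN features of the observation, concatenated with the action, fed to an MLP."""

    def __init__(self, n_channels: int, n_actions: int, features_dim: int = 64):
        super().__init__()
        self.n_actions = n_actions
        # Stand-in feature extractor (the README's class uses NatureCNN instead).
        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, features_dim),
            nn.ReLU(),
        )
        # MLP head over [observation features, one-hot action].
        self.mlp = nn.Sequential(
            nn.Linear(features_dim + n_actions, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, obs: torch.Tensor, acts: torch.Tensor) -> torch.Tensor:
        obs_features = self.cnn(obs)
        acts_one_hot = nn.functional.one_hot(acts, self.n_actions).float()
        joint = torch.cat([obs_features, acts_one_hot], dim=1)
        return self.mlp(joint).squeeze(-1)  # one logit per (obs, action) pair


disc = ActObsCNNSketch(n_channels=3, n_actions=7)
obs = torch.rand(5, 3, 56, 56)    # batch of 5 channel-first images (illustrative shape)
acts = torch.randint(0, 7, (5,))  # batch of 5 discrete actions
logits = disc(obs, acts)
print(logits.shape)
```

An observation-only variant in the spirit of `ObsOnlyCNN` would drop the concatenation and feed `obs_features` directly to the MLP head.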
To use the CNN version of GAIL or WGAIL, omit the `-f` argument.
A sample test script for CNN-GAIL: `python .\minigrid_gail_training_script.py -r testing_cnngail -t img_no_stack_minigrid_empty_down_right --vis-trained`