diff --git a/examples/run_pplm_discrim_train.py b/examples/run_pplm_discrim_train.py index 5291ad4b51..fccfb14426 100644 --- a/examples/run_pplm_discrim_train.py +++ b/examples/run_pplm_discrim_train.py @@ -545,10 +545,11 @@ def train_discriminator( if save_model: # torch.save(discriminator.state_dict(), # "{}_discriminator_{}.pt".format( - # args.dataset, epoch + # args.dataset, epoch + 1 # )) torch.save(discriminator.get_classifier().state_dict(), - "{}_classifier_head_epoch_{}.pt".format(dataset, epoch)) + "{}_classifier_head_epoch_{}.pt".format(dataset, + epoch + 1)) if __name__ == "__main__":