Added a +1 to epoch when saving weights
This commit is contained in:
@@ -545,10 +545,11 @@ def train_discriminator(
|
|||||||
if save_model:
|
if save_model:
|
||||||
# torch.save(discriminator.state_dict(),
|
# torch.save(discriminator.state_dict(),
|
||||||
# "{}_discriminator_{}.pt".format(
|
# "{}_discriminator_{}.pt".format(
|
||||||
# args.dataset, epoch
|
# args.dataset, epoch + 1
|
||||||
# ))
|
# ))
|
||||||
torch.save(discriminator.get_classifier().state_dict(),
|
torch.save(discriminator.get_classifier().state_dict(),
|
||||||
"{}_classifier_head_epoch_{}.pt".format(dataset, epoch))
|
"{}_classifier_head_epoch_{}.pt".format(dataset,
|
||||||
|
epoch + 1))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user