diff --git a/examples/run_pplm_discrim_train.py b/examples/run_pplm_discrim_train.py index 9438cbbac1..519e2de29a 100644 --- a/examples/run_pplm_discrim_train.py +++ b/examples/run_pplm_discrim_train.py @@ -242,10 +242,9 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False): def train_discriminator( dataset, dataset_fp=None, pretrained_model='gpt2-medium', epochs=10, batch_size=64, log_interval=10, - save_model=False, cached=False, use_cuda=False): - if use_cuda: - global device - device = 'cuda' + save_model=False, cached=False, no_cuda=False): + global device + device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu" print('Preprocessing {} dataset...'.format(dataset)) start = time.time() @@ -577,8 +576,8 @@ if __name__ == '__main__': help='whether to save the model') parser.add_argument('--cached', action='store_true', help='whether to cache the input representations') - parser.add_argument('--use_cuda', action='store_true', - help='use to turn on cuda') + parser.add_argument('--no_cuda', action='store_true', + help='use to turn off cuda') args = parser.parse_args() train_discriminator(**(vars(args)))