Improved device specification

This commit is contained in:
piero
2019-11-27 16:39:49 -08:00
committed by Julien Chaumond
parent 4f2164e40e
commit 7ffe47c888

View File

@@ -242,10 +242,9 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
def train_discriminator(
dataset, dataset_fp=None, pretrained_model='gpt2-medium',
epochs=10, batch_size=64, log_interval=10,
save_model=False, cached=False, use_cuda=False):
if use_cuda:
global device
device = 'cuda'
save_model=False, cached=False, no_cuda=False):
global device
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
print('Preprocessing {} dataset...'.format(dataset))
start = time.time()
@@ -577,8 +576,8 @@ if __name__ == '__main__':
help='whether to save the model')
parser.add_argument('--cached', action='store_true',
help='whether to cache the input representations')
parser.add_argument('--use_cuda', action='store_true',
help='use to turn on cuda')
parser.add_argument('--no_cuda', action='store_true',
help='use to turn off cuda')
args = parser.parse_args()
train_discriminator(**(vars(args)))