Improved device specification
This commit is contained in:
@@ -242,10 +242,9 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
|
||||
def train_discriminator(
|
||||
dataset, dataset_fp=None, pretrained_model='gpt2-medium',
|
||||
epochs=10, batch_size=64, log_interval=10,
|
||||
save_model=False, cached=False, use_cuda=False):
|
||||
if use_cuda:
|
||||
global device
|
||||
device = 'cuda'
|
||||
save_model=False, cached=False, no_cuda=False):
|
||||
global device
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
print('Preprocessing {} dataset...'.format(dataset))
|
||||
start = time.time()
|
||||
@@ -577,8 +576,8 @@ if __name__ == '__main__':
|
||||
help='whether to save the model')
|
||||
parser.add_argument('--cached', action='store_true',
|
||||
help='whether to cache the input representations')
|
||||
parser.add_argument('--use_cuda', action='store_true',
|
||||
help='use to turn on cuda')
|
||||
parser.add_argument('--no_cuda', action='store_true',
|
||||
help='use to turn off cuda')
|
||||
args = parser.parse_args()
|
||||
|
||||
train_discriminator(**(vars(args)))
|
||||
|
||||
Reference in New Issue
Block a user