Improved device specification
This commit is contained in:
@@ -242,10 +242,9 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
|
|||||||
def train_discriminator(
|
def train_discriminator(
|
||||||
dataset, dataset_fp=None, pretrained_model='gpt2-medium',
|
dataset, dataset_fp=None, pretrained_model='gpt2-medium',
|
||||||
epochs=10, batch_size=64, log_interval=10,
|
epochs=10, batch_size=64, log_interval=10,
|
||||||
save_model=False, cached=False, use_cuda=False):
|
save_model=False, cached=False, no_cuda=False):
|
||||||
if use_cuda:
|
global device
|
||||||
global device
|
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||||
device = 'cuda'
|
|
||||||
|
|
||||||
print('Preprocessing {} dataset...'.format(dataset))
|
print('Preprocessing {} dataset...'.format(dataset))
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -577,8 +576,8 @@ if __name__ == '__main__':
|
|||||||
help='whether to save the model')
|
help='whether to save the model')
|
||||||
parser.add_argument('--cached', action='store_true',
|
parser.add_argument('--cached', action='store_true',
|
||||||
help='whether to cache the input representations')
|
help='whether to cache the input representations')
|
||||||
parser.add_argument('--use_cuda', action='store_true',
|
parser.add_argument('--no_cuda', action='store_true',
|
||||||
help='use to turn on cuda')
|
help='use to turn off cuda')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
train_discriminator(**(vars(args)))
|
train_discriminator(**(vars(args)))
|
||||||
|
|||||||
Reference in New Issue
Block a user