load all models on cpu

This commit is contained in:
thomwolf
2019-04-15 15:43:01 +02:00
parent 2499b0a5fc
commit df5d9c3551
4 changed files with 4 additions and 4 deletions

View File

@@ -476,7 +476,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return load_tf_weights_in_openai_gpt(model, resolved_archive_file)