From df5d9c3551a6405feb697a1cad903dddffa04bfe Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Apr 2019 15:43:01 +0200 Subject: [PATCH] load all models on cpu --- pytorch_pretrained_bert/modeling.py | 2 +- pytorch_pretrained_bert/modeling_gpt2.py | 2 +- pytorch_pretrained_bert/modeling_openai.py | 2 +- pytorch_pretrained_bert/modeling_transfo_xl.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index dca6ac53f2..8dfb5fe51e 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -594,7 +594,7 @@ class BertPreTrainedModel(nn.Module): model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(weights_path, map_location='cpu') if tempdir: # Clean up temp dir shutil.rmtree(tempdir) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index e6017d33e4..7cf1e6b59d 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -418,7 +418,7 @@ class GPT2PreTrainedModel(nn.Module): # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint (stored as NumPy array) return load_tf_weights_in_gpt2(model, resolved_archive_file) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 57a7921d7a..3dedc53f11 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -476,7 +476,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint (stored as NumPy array) return load_tf_weights_in_openai_gpt(model, resolved_archive_file) diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index 0b732cdef1..e8fffc5b60 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -944,7 +944,7 @@ class TransfoXLPreTrainedModel(nn.Module): # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) + state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)