adding option to load on cpu
This commit is contained in:
@@ -580,7 +580,7 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
if state_dict is None and not from_tf:
|
if state_dict is None and not from_tf:
|
||||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||||
state_dict = torch.load(weights_path)
|
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||||
if tempdir:
|
if tempdir:
|
||||||
# Clean up temp dir
|
# Clean up temp dir
|
||||||
shutil.rmtree(tempdir)
|
shutil.rmtree(tempdir)
|
||||||
|
|||||||
Reference in New Issue
Block a user