Fix from_pretrained API with config and state_dict (#21542)
This commit is contained in:
@@ -2770,7 +2770,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
del state_dict[checkpoint_key]
|
del state_dict[checkpoint_key]
|
||||||
return mismatched_keys
|
return mismatched_keys
|
||||||
|
|
||||||
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
|
if resolved_archive_file is not None:
|
||||||
|
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
|
||||||
|
else:
|
||||||
|
folder = None
|
||||||
if device_map is not None and is_safetensors:
|
if device_map is not None and is_safetensors:
|
||||||
param_device_map = expand_device_map(device_map, original_loaded_keys)
|
param_device_map = expand_device_map(device_map, original_loaded_keys)
|
||||||
|
|
||||||
|
|||||||
@@ -2749,6 +2749,15 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
BertModel.from_pretrained(TINY_T5)
|
BertModel.from_pretrained(TINY_T5)
|
||||||
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_no_checkpoint(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
model = BertModel(config)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=state_dict)
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_model_from_config_torch_dtype(self):
|
def test_model_from_config_torch_dtype(self):
|
||||||
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
||||||
|
|||||||
Reference in New Issue
Block a user