Revert "[WIP] Hard error when ignoring tensors." (#28898)
Revert "[WIP] Hard error when ignoring tensors. (#27484)"
This reverts commit 2da28c4b41.
This commit is contained in:
@@ -257,26 +257,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_manually_shared_disjointed_tensors_optimum(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
model = BertModel(config)
|
||||
|
||||
# Let's fuse qkv
|
||||
attn = model.encoder.layer[0].attention.self
|
||||
q = attn.query.weight
|
||||
k = attn.key.weight
|
||||
v = attn.value.weight
|
||||
# Force some shared storage
|
||||
qkv = torch.stack([q, k, v], dim=0)
|
||||
attn.query.weight = torch.nn.Parameter(qkv[0])
|
||||
attn.key.weight = torch.nn.Parameter(qkv[1])
|
||||
attn.value.weight = torch.nn.Parameter(qkv[2])
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
model_loaded = BertModel.from_pretrained(tmp_dir)
|
||||
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_from_pretrained_subfolder_sharded(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
model = BertModel(config)
|
||||
|
||||
Reference in New Issue
Block a user