[WIP] Hard error when ignoring tensors. (#27484)

* [WIP] Hard error when ignoring tensors.

* Better selection/error when saving a checkpoint.

- Find all names we should normally drop (those are in the transformers
  config)
- Find all disjoint tensors (for those we can safely trigger a copy to
  get rid of the sharing before saving)
- Clone those disjoint tensors, which gets rid of the issue
- Find all identical names (those should be declared in the config
  but we try to find them all anyway.)
- For all identical names:
  - If they are in the config, just ignore them; everything is fine
  - If they are not, warn about them.
- For all remaining tensors which are shared yet neither identical NOR
  disjoint, raise a hard error.

* Adding a failing test on `main` that passes here.

* We don't need to keep the subfolder logic in this test.

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2024-02-05 09:17:24 +01:00
committed by GitHub
parent 0466fd5ca2
commit 2da28c4b41
2 changed files with 112 additions and 14 deletions

View File

@@ -257,6 +257,26 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_manually_shared_disjointed_tensors_optimum(self):
    """Save/reload round-trip for a model whose q/k/v weights were fused by hand.

    The query, key and value weights of the first encoder layer are stacked
    into one tensor, and each projection is then given a ``Parameter`` built
    from its own slice of that stacked tensor. The model is saved to a
    temporary directory, loaded back, and compared with the original.
    """
    bert_config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
    model = BertModel(bert_config)

    # Fuse q, k and v of the first layer (order matters: query, key, value).
    self_attention = model.encoder.layer[0].attention.self
    projections = (self_attention.query, self_attention.key, self_attention.value)

    # Stacking produces a single tensor; indexing it below yields slices
    # that force some shared storage between the three weights.
    fused_weights = torch.stack([proj.weight for proj in projections], dim=0)
    for position, proj in enumerate(projections):
        proj.weight = torch.nn.Parameter(fused_weights[position])

    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        reloaded_model = BertModel.from_pretrained(save_dir)

    self.assertTrue(check_models_equal(model, reloaded_model))
def test_model_from_pretrained_subfolder_sharded(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)