Hard error when ignoring tensors. (#27484) (#29906)

* Hard error when ignoring tensors. (#27484)

* [WIP] Hard error when ignoring tensors.

* Better selection/error when saving a checkpoint.

- Find all names we should normally drop (those are in the transformers
  config)
- Find all disjoint tensors (for those we can safely trigger a copy to
  get rid of the sharing before saving)
- Clone those disjoint tensors getting rid of the issue
- Find all identical names (those should be declared in the config
  but we try to find them all anyway.)
- For all identical names:
  - If they are in the config, just ignore them everything is fine
  - If they are not, warn about them.
- For all remainder tensors which are shared yet neither identical NOR
  disjoint. raise a hard error.

* Adding a failing test on `main` that passes here.

* We don't need to keep the subfolder logic in this test.

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add small tests.

* Dead variable.

* Fixup.

* Fixing tied_Weights_keys on generic models.

* Fixup + T5 encoder/decoder tying (with different layers)

* Code quality.

* Dynamic member.

* trigger

* Fixing encoder name for other types of encoder/decoder combos.

* Fix scoping.

* Update .github/workflows/self-scheduled.yml

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fixing the tied_weights after the call.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2024-04-02 16:59:05 +02:00
committed by GitHub
parent 15cd68713d
commit 9b0a8ea7d1
7 changed files with 225 additions and 33 deletions

View File

@@ -101,7 +101,7 @@ if is_torch_available():
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
)
from transformers.modeling_utils import shard_checkpoint
from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
@@ -256,6 +256,26 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_manually_shared_disjointed_tensors_optimum(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)
# Let's fuse qkv
attn = model.encoder.layer[0].attention.self
q = attn.query.weight
k = attn.key.weight
v = attn.value.weight
# Force some shared storage
qkv = torch.stack([q, k, v], dim=0)
attn.query.weight = torch.nn.Parameter(qkv[0])
attn.key.weight = torch.nn.Parameter(qkv[1])
attn.value.weight = torch.nn.Parameter(qkv[2])
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model_loaded = BertModel.from_pretrained(tmp_dir)
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_from_pretrained_subfolder_sharded(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)
@@ -2222,3 +2242,40 @@ class Mask4DTestHard(unittest.TestCase):
]
self.assertEqual(decoded_0, decoded_1b)
@require_torch
class TestTensorSharing(TestCasePlus):
def test_disjoint(self):
main = torch.zeros(10)
a = main[:5]
b = main[5:]
state_dict = {"a": a, "b": b}
shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
self.assertEqual(shared_names, [])
self.assertEqual(disjoint_names, ["a", "b"])
a = main[::2]
b = main[1::2]
state_dict = {"a": a, "b": b}
shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
self.assertEqual(shared_names, [{"a", "b"}])
self.assertEqual(disjoint_names, [])
def test_identical(self):
a = torch.zeros(10)
b = a
state_dict = {"a": a, "b": b}
shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
self.assertEqual(shared_names, [])
self.assertEqual(identical_names, [{"a", "b"}])
b = a[:5]
state_dict = {"a": a, "b": b}
shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
self.assertEqual(shared_names, [{"a", "b"}])
self.assertEqual(identical_names, [])