* Hard error when ignoring tensors. (#27484) * [WIP] Hard error when ignoring tensors. * Better selection/error when saving a checkpoint. - Find all names we should normally drop (those are in the transformers config) - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving) - Clone those disjoint tensors getting rid of the issue - Find all identical names (those should be declared in the config but we try to find them all anyway.) - For all identical names: - If they are in the config, just ignore them, everything is fine - If they are not, warn about them. - For all remaining tensors which are shared yet neither identical NOR disjoint, raise a hard error. * Adding a failing test on `main` that passes here. * We don't need to keep the subfolder logic in this test. * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add small tests. * Dead variable. * Fixup. * Fixing tied_weights_keys on generic models. * Fixup + T5 encoder/decoder tying (with different layers) * Code quality. * Dynamic member. * trigger * Fixing encoder name for other types of encoder/decoder combos. * Fix scoping. * Update .github/workflows/self-scheduled.yml Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fixing the tied_weights after the call. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -101,7 +101,7 @@ if is_torch_available():
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
)
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
|
||||
|
||||
# Fake pretrained models for tests
|
||||
class BaseModel(PreTrainedModel):
|
||||
@@ -256,6 +256,26 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_manually_shared_disjointed_tensors_optimum(self):
    """Save and reload a BERT model whose q/k/v weights are slices of one stacked tensor.

    The query, key and value weights of the first encoder layer are replaced
    by parameters built from rows of a single `torch.stack` result, so the
    three parameters are backed by the same storage. The model must still
    round-trip through `save_pretrained` / `from_pretrained` unchanged.
    """
    bert_config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
    model = BertModel(bert_config)

    # Fuse q, k and v of the first layer into one stacked tensor
    self_attention = model.encoder.layer[0].attention.self
    projections = (self_attention.query, self_attention.key, self_attention.value)
    fused = torch.stack([projection.weight for projection in projections], dim=0)

    # Re-register each weight as a slice of the fused tensor to force shared storage
    for row, projection in enumerate(projections):
        projection.weight = torch.nn.Parameter(fused[row])

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        model_loaded = BertModel.from_pretrained(tmp_dir)

    self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_from_pretrained_subfolder_sharded(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
model = BertModel(config)
|
||||
@@ -2222,3 +2242,40 @@ class Mask4DTestHard(unittest.TestCase):
|
||||
]
|
||||
|
||||
self.assertEqual(decoded_0, decoded_1b)
|
||||
|
||||
|
||||
@require_torch
class TestTensorSharing(TestCasePlus):
    """Exercise `_find_disjoint` and `_find_identical` on tensors backed by one buffer."""

    def test_disjoint(self):
        """Contiguous halves are reported as disjoint; interleaved strides stay shared."""
        buffer = torch.zeros(10)

        # Case 1: two non-overlapping contiguous halves of the same buffer.
        halves = {"a": buffer[:5], "b": buffer[5:]}
        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], halves)
        self.assertEqual(shared_names, [])
        self.assertEqual(disjoint_names, ["a", "b"])

        # Case 2: even/odd strided views interleave inside the buffer.
        interleaved = {"a": buffer[::2], "b": buffer[1::2]}
        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], interleaved)
        self.assertEqual(shared_names, [{"a", "b"}])
        self.assertEqual(disjoint_names, [])

    def test_identical(self):
        """The same tensor under two names is identical; a partial slice is only shared."""
        full = torch.zeros(10)

        # Case 1: both names point at the very same tensor object.
        aliased = {"a": full, "b": full}
        shared_names, identical_names = _find_identical([{"a", "b"}], aliased)
        self.assertEqual(shared_names, [])
        self.assertEqual(identical_names, [{"a", "b"}])

        # Case 2: "b" covers only the first half of "a".
        partial = {"a": full, "b": full[:5]}
        shared_names, identical_names = _find_identical([{"a", "b"}], partial)
        self.assertEqual(shared_names, [{"a", "b"}])
        self.assertEqual(identical_names, [])
|
||||
|
||||
Reference in New Issue
Block a user