Tied weights load (#24310)
* Use tied weight keys
* More
* Fix tied weight missing warning
* Only give info on unexpected keys with different classes
* Deal with empty archs
* Fix tests
* Refine test
This commit is contained in:
@@ -1779,10 +1779,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
for names in shared_ptrs.values():
|
for names in shared_ptrs.values():
|
||||||
# Removing the keys which are declared as known duplicates on
|
# Removing the keys which are declared as known duplicates on
|
||||||
# load. This allows to make sure the name which is kept is consistent.
|
# load. This allows to make sure the name which is kept is consistent.
|
||||||
if self._keys_to_ignore_on_load_missing is not None:
|
if self._tied_weights_keys is not None:
|
||||||
found = 0
|
found = 0
|
||||||
for name in sorted(names):
|
for name in sorted(names):
|
||||||
matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
|
matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
|
||||||
if matches_pattern and name in state_dict:
|
if matches_pattern and name in state_dict:
|
||||||
found += 1
|
found += 1
|
||||||
if found < len(names):
|
if found < len(names):
|
||||||
@@ -3020,22 +3020,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
|
model.tie_weights()
|
||||||
tied_params = find_tied_parameters(model)
|
tied_params = find_tied_parameters(model)
|
||||||
else:
|
else:
|
||||||
tied_params = []
|
tied_params = []
|
||||||
_missing = []
|
|
||||||
for k in missing_keys:
|
for group in tied_params:
|
||||||
found = False
|
missing_in_group = [k for k in missing_keys if k in group]
|
||||||
for group in tied_params:
|
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
|
||||||
if k in group:
|
missing_keys = [k for k in missing_keys if k not in missing_in_group]
|
||||||
found = True
|
|
||||||
if len(group) > 2:
|
|
||||||
group.remove(k)
|
|
||||||
else:
|
|
||||||
_missing.append(k)
|
|
||||||
if not found:
|
|
||||||
_missing.append(k)
|
|
||||||
missing_keys = _missing
|
|
||||||
|
|
||||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||||
# the user.
|
# the user.
|
||||||
@@ -3275,7 +3268,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
|
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
logger.warning(
|
archs = [] if model.config.architectures is None else model.config.architectures
|
||||||
|
warner = logger.warn if model.__class__.__name__ in archs else logger.info
|
||||||
|
warner(
|
||||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||||
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
||||||
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
||||||
|
|||||||
@@ -82,16 +82,31 @@ if is_torch_available():
|
|||||||
|
|
||||||
# Fake pretrained models for tests
|
# Fake pretrained models for tests
|
||||||
class BaseModel(PreTrainedModel):
|
class BaseModel(PreTrainedModel):
|
||||||
|
base_model_prefix = "base"
|
||||||
config_class = PretrainedConfig
|
config_class = PretrainedConfig
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.linear = nn.Linear(4, 5)
|
self.linear = nn.Linear(5, 5)
|
||||||
self.linear_2 = nn.Linear(5, 6)
|
self.linear_2 = nn.Linear(5, 5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.linear_2(self.linear(x))
|
return self.linear_2(self.linear(x))
|
||||||
|
|
||||||
|
class BaseModelWithTiedWeights(PreTrainedModel):
    """Fake pretrained model with two linear layers whose weights can be tied.

    After `tie_weights()` runs, `linear_2.weight` is the very same parameter
    object as `linear.weight`.
    """

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        # Both layers are 5x5, so their weight tensors have identical shapes.
        self.linear = nn.Linear(5, 5)
        self.linear_2 = nn.Linear(5, 5)

    def forward(self, x):
        """Apply `linear` and then `linear_2` to `x`."""
        hidden = self.linear(x)
        return self.linear_2(hidden)

    def tie_weights(self):
        """Make `linear_2` reuse the weight parameter of `linear`."""
        self.linear_2.weight = self.linear.weight
|
||||||
|
|
||||||
class ModelWithHead(PreTrainedModel):
|
class ModelWithHead(PreTrainedModel):
|
||||||
base_model_prefix = "base"
|
base_model_prefix = "base"
|
||||||
config_class = PretrainedConfig
|
config_class = PretrainedConfig
|
||||||
@@ -103,12 +118,30 @@ if is_torch_available():
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.base = BaseModel(config)
|
self.base = BaseModel(config)
|
||||||
# linear is a common name between Base and Head on purpose.
|
# linear is a common name between Base and Head on purpose.
|
||||||
self.linear = nn.Linear(6, 3)
|
self.linear = nn.Linear(5, 5)
|
||||||
self.linear2 = nn.Linear(3, 5)
|
self.linear2 = nn.Linear(5, 5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.linear2(self.linear(self.base(x)))
|
return self.linear2(self.linear(self.base(x)))
|
||||||
|
|
||||||
|
class ModelWithHeadAndTiedWeights(PreTrainedModel):
    """Fake pretrained model: a `BaseModel` under `base` plus a `decoder` head.

    After `tie_weights()` runs, `decoder.weight` is the very same parameter
    object as `base.linear.weight`.
    """

    base_model_prefix = "base"
    config_class = PretrainedConfig

    def _init_weights(self, module):
        """Do nothing: this hook deliberately leaves `module` untouched."""
        return None

    def __init__(self, config):
        super().__init__(config)
        self.base = BaseModel(config)
        self.decoder = nn.Linear(5, 5)

    def forward(self, x):
        """Run `x` through the base model and then through the decoder head."""
        base_output = self.base(x)
        return self.decoder(base_output)

    def tie_weights(self):
        """Make the decoder head reuse the weight of the base's first layer."""
        self.decoder.weight = self.base.linear.weight
|
||||||
|
|
||||||
|
|
||||||
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||||
@@ -857,6 +890,54 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
):
|
):
|
||||||
_ = ModelWithHead.from_pretrained(tmp_dir)
|
_ = ModelWithHead.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
def test_tied_weights_reload(self):
    """Tied weights must stay tied across save/reload.

    Three scenarios are exercised against the same temporary directory:
    a plain round trip, a checkpoint from which the tied weight was removed,
    and loading the base checkpoint into a model with a tied head.
    """
    # Base
    original = BaseModelWithTiedWeights(PretrainedConfig())
    with tempfile.TemporaryDirectory() as tmp_dir:
        weights_path = os.path.join(tmp_dir, WEIGHTS_NAME)

        original.save_pretrained(tmp_dir)
        reloaded = BaseModelWithTiedWeights.from_pretrained(tmp_dir)
        self.assertIs(reloaded.linear.weight, reloaded.linear_2.weight)

        # Drop the tied weight from the checkpoint: loading must still succeed
        # without reporting any missing key.
        trimmed_state = original.state_dict()
        trimmed_state.pop("linear_2.weight")
        torch.save(trimmed_state, weights_path)
        reloaded, info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
        self.assertListEqual(info["missing_keys"], [])
        self.assertIs(reloaded.linear.weight, reloaded.linear_2.weight)

        # With head
        original.save_pretrained(tmp_dir)
        reloaded, info = ModelWithHeadAndTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
        self.assertIs(reloaded.base.linear.weight, reloaded.decoder.weight)
        # The decoder weight is tied, so only its bias should be reported.
        self.assertListEqual(info["missing_keys"], ["decoder.bias"])
|
||||||
|
|
||||||
|
def test_unexpected_keys_warnings(self):
    """Unexpected checkpoint keys warn only when the loading class matches.

    A checkpoint saved from `ModelWithHead` is loaded twice:
    - with `BaseModel` (a different class): the unexpected head weights are
      still reported in the loading info, but the "were not used" message
      must not show up in the captured output;
    - with `ModelWithHead` (the same class) after injecting an extra key:
      the message must show up and name that key.
    """
    model = ModelWithHead(PretrainedConfig())
    logger = logging.get_logger("transformers.modeling_utils")
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)

        # Loading the model with a new class, we don't get a warning for unexpected weights, just an info
        with CaptureLogger(logger) as cl:
            _, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
        # The message names the class being initialized, which is BaseModel
        # here. Checking for "ModelWithHead" could never match and would make
        # this assertion vacuous.
        # NOTE(review): this relies on INFO-level records not reaching the
        # captured output - confirm the verbosity used by the test run.
        self.assertNotIn("were not used when initializing BaseModel", cl.out)
        self.assertEqual(
            set(loading_info["unexpected_keys"]),
            {"linear.weight", "linear.bias", "linear2.weight", "linear2.bias"},
        )

        # Loading the model with the same class, we do get a warning for unexpected weights
        state_dict = model.state_dict()
        state_dict["added_key"] = state_dict["linear.weight"]
        torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
        with CaptureLogger(logger) as cl:
            _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
        self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
        self.assertEqual(loading_info["unexpected_keys"], ["added_key"])
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@slow
|
@slow
|
||||||
def test_pretrained_low_mem_new_config(self):
|
def test_pretrained_low_mem_new_config(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user