Cleaner warning when loading pretrained models (#4557)

* Cleaner warning when loading pretrained models

This make more explicit logging messages when using the various `from_pretrained` methods. It also make these messages as `logging.warning` because it's a common source of silent mistakes.

* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* style and quality

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Thomas Wolf
2020-06-22 21:58:47 +02:00
committed by GitHub
parent 4e741efa92
commit 75e1eed8d1
3 changed files with 89 additions and 22 deletions

View File

@@ -750,17 +750,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys
)
)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.info(
"Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys
)
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
if len(error_msgs) > 0:
raise RuntimeError(