Cleaner warning when loading pretrained models (#4557)
* Cleaner warning when loading pretrained models This make more explicit logging messages when using the various `from_pretrained` methods. It also make these messages as `logging.warning` because it's a common source of silent mistakes. * Update src/transformers/modeling_utils.py Co-authored-by: Julien Chaumond <chaumond@gmail.com> * Update src/transformers/modeling_utils.py Co-authored-by: Julien Chaumond <chaumond@gmail.com> * style and quality Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -750,17 +750,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
logger.info(
|
||||
"Weights of {} not initialized from pretrained model: {}".format(
|
||||
model.__class__.__name__, missing_keys
|
||||
)
|
||||
)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
|
||||
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
|
||||
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
|
||||
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
|
||||
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||
)
|
||||
else:
|
||||
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||
if len(missing_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
||||
f"and are newly initialized: {missing_keys}\n"
|
||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Weights from pretrained model not used in {}: {}".format(
|
||||
model.__class__.__name__, unexpected_keys
|
||||
)
|
||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
||||
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
||||
f"you can already use {model.__class__.__name__} for predictions without further training."
|
||||
)
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user