From 75e1eed8d190afa5be30fba05cd872d79b492a24 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Mon, 22 Jun 2020 21:58:47 +0200 Subject: [PATCH] Cleaner warning when loading pretrained models (#4557) * Cleaner warning when loading pretrained models This makes the logging messages more explicit when using the various `from_pretrained` methods. It also emits these messages at `logging.warning` level because this is a common source of silent mistakes. * Update src/transformers/modeling_utils.py Co-authored-by: Julien Chaumond * Update src/transformers/modeling_utils.py Co-authored-by: Julien Chaumond * style and quality Co-authored-by: Julien Chaumond --- src/transformers/modeling_tf_pytorch_utils.py | 55 ++++++++++++++++--- src/transformers/modeling_tf_utils.py | 27 +++++++-- src/transformers/modeling_utils.py | 29 +++++++--- 3 files changed, 89 insertions(+), 22 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index d8012068aa..5052695bc5 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -150,6 +150,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a tf_loaded_numel = 0 weight_value_tuples = [] all_pytorch_weights = set(list(pt_state_dict.keys())) + unexpected_keys = [] for symbolic_weight in symbolic_weights: sw_name = symbolic_weight.name name, transpose = convert_tf_weight_name_to_pt_weight_name( @@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a # Find associated numpy array in pytorch model state dict if name not in pt_state_dict: if allow_missing_keys: + unexpected_keys.append(name) continue raise AttributeError("{} not found in PyTorch model".format(name)) @@ -192,7 +194,31 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel)) - logger.info("Weights 
or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights)) + missing_keys = list(all_pytorch_weights) + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the PyTorch model were not used when " + f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n" + f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)." + ) + else: + logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights or buffers of the PyTorch model {tf_model.__class__.__name__} were not initialized from the TF 2.0 model " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {tf_model.__class__.__name__} were initialized from the TF 2.0 model.\n" + f"If your task is similar to the task the model of the ckeckpoint was trained on, " + f"you can already use {tf_model.__class__.__name__} for predictions without further training." 
+ ) return tf_model @@ -317,13 +343,28 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False) missing_keys += missing_keys_pt - if len(missing_keys) > 0: - logger.info( - "Weights of {} not initialized from TF 2.0 model: {}".format(pt_model.__class__.__name__, missing_keys) - ) if len(unexpected_keys) > 0: - logger.info( - "Weights from TF 2.0 model not used in {}: {}".format(pt_model.__class__.__name__, unexpected_keys) + logger.warning( + f"Some weights of the TF 2.0 model were not used when " + f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n" + f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)." + ) + else: + logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n" + f"If your task is similar to the task the model of the ckeckpoint was trained on, " + f"you can already use {pt_model.__class__.__name__} for predictions without further training." 
) logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights)) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e6a4a37a13..405a7f6555 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -504,13 +504,28 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): unexpected_keys = list(hdf5_layer_names - model_layer_names) error_msgs = [] - if len(missing_keys) > 0: - logger.info( - "Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys) - ) if len(unexpected_keys) > 0: - logger.info( - "Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys) + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." 
+ ) + else: + logger.warning( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"If your task is similar to the task the model of the ckeckpoint was trained on, " + f"you can already use {model.__class__.__name__} for predictions without further training." ) if len(error_msgs) > 0: raise RuntimeError( diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bb9c53d7d2..e167abd089 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -750,17 +750,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) - if len(missing_keys) > 0: - logger.info( - "Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys - ) - ) if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." 
+ ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: logger.info( - "Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, unexpected_keys - ) + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"If your task is similar to the task the model of the ckeckpoint was trained on, " + f"you can already use {model.__class__.__name__} for predictions without further training." ) if len(error_msgs) > 0: raise RuntimeError(