Proposed fix for TF example now running on safetensors. (#23208)
* Proposed fix for TF example now running on safetensors. * Adding more warnings and returning keys. * Trigger CI * Trigger CI --------- Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
@@ -297,7 +297,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertGreaterEqual(result["bleu"], 30)
|
self.assertGreaterEqual(result["bleu"], 30)
|
||||||
|
|
||||||
@skip("Fix me Matt")
|
|
||||||
def test_run_image_classification(self):
|
def test_run_image_classification(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
|
|||||||
@@ -246,6 +246,7 @@ def load_pytorch_state_dict_in_tf2_model(
|
|||||||
output_loading_info=False,
|
output_loading_info=False,
|
||||||
_prefix=None,
|
_prefix=None,
|
||||||
tf_to_pt_weight_rename=None,
|
tf_to_pt_weight_rename=None,
|
||||||
|
ignore_mismatched_sizes=False,
|
||||||
):
|
):
|
||||||
"""Load a pytorch state_dict in a TF 2.0 model."""
|
"""Load a pytorch state_dict in a TF 2.0 model."""
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -297,6 +298,7 @@ def load_pytorch_state_dict_in_tf2_model(
|
|||||||
weight_value_tuples = []
|
weight_value_tuples = []
|
||||||
all_pytorch_weights = set(pt_state_dict.keys())
|
all_pytorch_weights = set(pt_state_dict.keys())
|
||||||
missing_keys = []
|
missing_keys = []
|
||||||
|
mismatched_keys = []
|
||||||
for symbolic_weight in symbolic_weights:
|
for symbolic_weight in symbolic_weights:
|
||||||
sw_name = symbolic_weight.name
|
sw_name = symbolic_weight.name
|
||||||
name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
||||||
@@ -319,7 +321,18 @@ def load_pytorch_state_dict_in_tf2_model(
|
|||||||
continue
|
continue
|
||||||
raise AttributeError(f"{name} not found in PyTorch model")
|
raise AttributeError(f"{name} not found in PyTorch model")
|
||||||
|
|
||||||
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
|
try:
|
||||||
|
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
|
||||||
|
except tf.errors.InvalidArgumentError as e:
|
||||||
|
if not ignore_mismatched_sizes:
|
||||||
|
error_msg = str(e)
|
||||||
|
error_msg += (
|
||||||
|
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
||||||
|
)
|
||||||
|
raise tf.errors.InvalidArgumentError(error_msg)
|
||||||
|
else:
|
||||||
|
mismatched_keys.append((name, pt_state_dict[name].shape, symbolic_weight.shape))
|
||||||
|
continue
|
||||||
|
|
||||||
tf_loaded_numel += tensor_size(array)
|
tf_loaded_numel += tensor_size(array)
|
||||||
|
|
||||||
@@ -367,8 +380,26 @@ def load_pytorch_state_dict_in_tf2_model(
|
|||||||
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
|
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(mismatched_keys) > 0:
|
||||||
|
mismatched_warning = "\n".join(
|
||||||
|
[
|
||||||
|
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||||
|
for key, shape1, shape2 in mismatched_keys
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
|
||||||
|
f" are newly initialized because the shapes did not"
|
||||||
|
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
|
||||||
|
" to use it for predictions and inference."
|
||||||
|
)
|
||||||
|
|
||||||
if output_loading_info:
|
if output_loading_info:
|
||||||
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
|
loading_info = {
|
||||||
|
"missing_keys": missing_keys,
|
||||||
|
"unexpected_keys": unexpected_keys,
|
||||||
|
"mismatched_keys": mismatched_keys,
|
||||||
|
}
|
||||||
return tf_model, loading_info
|
return tf_model, loading_info
|
||||||
|
|
||||||
return tf_model
|
return tf_model
|
||||||
|
|||||||
@@ -2820,6 +2820,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
allow_missing_keys=True,
|
allow_missing_keys=True,
|
||||||
output_loading_info=output_loading_info,
|
output_loading_info=output_loading_info,
|
||||||
_prefix=load_weight_prefix,
|
_prefix=load_weight_prefix,
|
||||||
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||||
|
|||||||
Reference in New Issue
Block a user