Refactor conversion function (#19799)
* Refactor conversion function * Remove dupe line * Fixes * Fixes * Use the right variable... * Fix last test
This commit is contained in:
@@ -21,7 +21,8 @@ import re
|
||||
|
||||
import numpy
|
||||
|
||||
from .utils import ExplicitEnum, logging
|
||||
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze
|
||||
from .utils import transpose as transpose_func
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -66,10 +67,12 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
|
||||
if len(tf_name) > 1:
|
||||
tf_name = tf_name[1:] # Remove level zero
|
||||
|
||||
tf_weight_shape = list(tf_weight_shape)
|
||||
|
||||
# When should we transpose the weights
|
||||
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
|
||||
if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
|
||||
transpose = TransposeType.CONV2D
|
||||
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 3:
|
||||
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
|
||||
transpose = TransposeType.CONV1D
|
||||
elif bool(
|
||||
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
|
||||
@@ -98,6 +101,43 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
|
||||
return tf_name, transpose
|
||||
|
||||
|
||||
def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
|
||||
"""
|
||||
Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a
|
||||
framework agnostic way.
|
||||
"""
|
||||
if transpose is TransposeType.CONV2D:
|
||||
# Conv2D weight:
|
||||
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||
axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
|
||||
weight = transpose_func(weight, axes=axes)
|
||||
elif transpose is TransposeType.CONV1D:
|
||||
# Conv1D weight:
|
||||
# PT: (num_out_channel, num_in_channel, kernel)
|
||||
# -> TF: (kernel, num_in_channel, num_out_channel)
|
||||
weight = transpose_func(weight, axes=(2, 1, 0))
|
||||
elif transpose is TransposeType.SIMPLE:
|
||||
weight = transpose_func(weight)
|
||||
|
||||
if match_shape is None:
|
||||
return weight
|
||||
|
||||
if len(match_shape) < len(weight.shape):
|
||||
weight = squeeze(weight)
|
||||
elif len(match_shape) > len(weight.shape):
|
||||
weight = expand_dims(weight, axis=0)
|
||||
|
||||
if list(match_shape) != list(weight.shape):
|
||||
try:
|
||||
weight = reshape(weight, match_shape)
|
||||
except AssertionError as e:
|
||||
e.args += (match_shape, match_shape)
|
||||
raise e
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
#####################
|
||||
# PyTorch => TF 2.0 #
|
||||
#####################
|
||||
@@ -155,7 +195,6 @@ def load_pytorch_weights_in_tf2_model(
|
||||
try:
|
||||
import tensorflow as tf # noqa: F401
|
||||
import torch # noqa: F401
|
||||
from tensorflow.python.keras import backend as K
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||
@@ -163,6 +202,22 @@ def load_pytorch_weights_in_tf2_model(
|
||||
)
|
||||
raise
|
||||
|
||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||
return load_pytorch_state_dict_in_tf2_model(
|
||||
tf_model,
|
||||
pt_state_dict,
|
||||
tf_inputs=tf_inputs,
|
||||
allow_missing_keys=allow_missing_keys,
|
||||
output_loading_info=output_loading_info,
|
||||
)
|
||||
|
||||
|
||||
def load_pytorch_state_dict_in_tf2_model(
|
||||
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
|
||||
):
|
||||
"""Load a pytorch state_dict in a TF 2.0 model."""
|
||||
from tensorflow.python.keras import backend as K
|
||||
|
||||
if tf_inputs is None:
|
||||
tf_inputs = tf_model.dummy_inputs
|
||||
|
||||
@@ -216,41 +271,9 @@ def load_pytorch_weights_in_tf2_model(
|
||||
continue
|
||||
raise AttributeError(f"{name} not found in PyTorch model")
|
||||
|
||||
array = pt_state_dict[name].numpy()
|
||||
|
||||
if transpose is TransposeType.CONV2D:
|
||||
# Conv2D weight:
|
||||
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||
array = numpy.transpose(array, axes=(2, 3, 1, 0))
|
||||
elif transpose is TransposeType.CONV1D:
|
||||
# Conv1D weight:
|
||||
# PT: (num_out_channel, num_in_channel, kernel)
|
||||
# -> TF: (kernel, num_in_channel, num_out_channel)
|
||||
array = numpy.transpose(array, axes=(2, 1, 0))
|
||||
elif transpose is TransposeType.SIMPLE:
|
||||
array = numpy.transpose(array)
|
||||
|
||||
if len(symbolic_weight.shape) < len(array.shape):
|
||||
array = numpy.squeeze(array)
|
||||
elif len(symbolic_weight.shape) > len(array.shape):
|
||||
array = numpy.expand_dims(array, axis=0)
|
||||
|
||||
if list(symbolic_weight.shape) != list(array.shape):
|
||||
try:
|
||||
array = numpy.reshape(array, symbolic_weight.shape)
|
||||
except AssertionError as e:
|
||||
e.args += (symbolic_weight.shape, array.shape)
|
||||
raise e
|
||||
|
||||
try:
|
||||
assert list(symbolic_weight.shape) == list(array.shape)
|
||||
except AssertionError as e:
|
||||
e.args += (symbolic_weight.shape, array.shape)
|
||||
raise e
|
||||
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
|
||||
|
||||
tf_loaded_numel += array.size
|
||||
# logger.warning(f"Initialize TF weight {symbolic_weight.name}")
|
||||
|
||||
weight_value_tuples.append((symbolic_weight, array))
|
||||
all_pytorch_weights.discard(name)
|
||||
@@ -370,6 +393,15 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
)
|
||||
raise
|
||||
|
||||
tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
|
||||
return load_tf2_state_dict_in_pytorch_model(
|
||||
pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
|
||||
)
|
||||
|
||||
|
||||
def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
|
||||
import torch
|
||||
|
||||
new_pt_params_dict = {}
|
||||
current_pt_params_dict = dict(pt_model.named_parameters())
|
||||
|
||||
@@ -381,11 +413,11 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
|
||||
# Build a map from potential PyTorch weight names to TF 2.0 Variables
|
||||
tf_weights_map = {}
|
||||
for tf_weight in tf_weights:
|
||||
for name, tf_weight in tf_state_dict.items():
|
||||
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
||||
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
|
||||
name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
|
||||
)
|
||||
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
|
||||
tf_weights_map[pt_name] = (tf_weight, transpose)
|
||||
|
||||
all_tf_weights = set(list(tf_weights_map.keys()))
|
||||
loaded_pt_weights_data_ptr = {}
|
||||
@@ -406,43 +438,18 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
|
||||
array, transpose = tf_weights_map[pt_weight_name]
|
||||
|
||||
if transpose is TransposeType.CONV2D:
|
||||
# Conv2D weight:
|
||||
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||
array = numpy.transpose(array, axes=(3, 2, 0, 1))
|
||||
elif transpose is TransposeType.CONV1D:
|
||||
# Conv1D weight:
|
||||
# TF: (kernel, num_in_channel, num_out_channel)
|
||||
# -> PT: (num_out_channel, num_in_channel, kernel)
|
||||
array = numpy.transpose(array, axes=(2, 1, 0))
|
||||
elif transpose is TransposeType.SIMPLE:
|
||||
array = numpy.transpose(array)
|
||||
array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
|
||||
|
||||
if len(pt_weight.shape) < len(array.shape):
|
||||
array = numpy.squeeze(array)
|
||||
elif len(pt_weight.shape) > len(array.shape):
|
||||
array = numpy.expand_dims(array, axis=0)
|
||||
|
||||
if list(pt_weight.shape) != list(array.shape):
|
||||
try:
|
||||
array = numpy.reshape(array, pt_weight.shape)
|
||||
except AssertionError as e:
|
||||
e.args += (pt_weight.shape, array.shape)
|
||||
raise e
|
||||
|
||||
try:
|
||||
assert list(pt_weight.shape) == list(array.shape)
|
||||
except AssertionError as e:
|
||||
e.args += (pt_weight.shape, array.shape)
|
||||
raise e
|
||||
|
||||
# logger.warning(f"Initialize PyTorch weight {pt_weight_name}")
|
||||
# Make sure we have a proper numpy array
|
||||
if numpy.isscalar(array):
|
||||
array = numpy.array(array)
|
||||
new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
|
||||
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
|
||||
if not is_torch_tensor(array) and not is_numpy_array(array):
|
||||
array = array.numpy()
|
||||
if is_numpy_array(array):
|
||||
# Convert to torch tensor
|
||||
array = torch.from_numpy(array)
|
||||
|
||||
new_pt_params_dict[pt_weight_name] = array
|
||||
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
|
||||
all_tf_weights.discard(pt_weight_name)
|
||||
|
||||
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
|
||||
|
||||
@@ -38,6 +38,7 @@ from .generic import (
|
||||
PaddingStrategy,
|
||||
TensorType,
|
||||
cached_property,
|
||||
expand_dims,
|
||||
find_labels,
|
||||
flatten_dict,
|
||||
is_jax_tensor,
|
||||
@@ -46,8 +47,11 @@ from .generic import (
|
||||
is_tf_tensor,
|
||||
is_torch_device,
|
||||
is_torch_tensor,
|
||||
reshape,
|
||||
squeeze,
|
||||
to_numpy,
|
||||
to_py_obj,
|
||||
transpose,
|
||||
working_or_temp_dir,
|
||||
)
|
||||
from .hub import (
|
||||
|
||||
@@ -29,6 +29,13 @@ import numpy as np
|
||||
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class cached_property(property):
|
||||
"""
|
||||
Descriptor that mimics @property but caches output in member variable.
|
||||
@@ -370,3 +377,71 @@ def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
|
||||
yield tmp_dir
|
||||
else:
|
||||
yield working_dir
|
||||
|
||||
|
||||
def transpose(array, axes=None):
|
||||
"""
|
||||
Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||
arrays.
|
||||
"""
|
||||
if is_numpy_array(array):
|
||||
return np.transpose(array, axes=axes)
|
||||
elif is_torch_tensor(array):
|
||||
return array.T if axes is None else array.permute(*axes)
|
||||
elif is_tf_tensor(array):
|
||||
return tf.transpose(array, perm=axes)
|
||||
elif is_jax_tensor(array):
|
||||
return jnp.transpose(array, axes=axes)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for transpose: {type(array)}.")
|
||||
|
||||
|
||||
def reshape(array, newshape):
|
||||
"""
|
||||
Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||
arrays.
|
||||
"""
|
||||
if is_numpy_array(array):
|
||||
return np.reshape(array, newshape)
|
||||
elif is_torch_tensor(array):
|
||||
return array.reshape(*newshape)
|
||||
elif is_tf_tensor(array):
|
||||
return tf.reshape(array, newshape)
|
||||
elif is_jax_tensor(array):
|
||||
return jnp.reshape(array, newshape)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for reshape: {type(array)}.")
|
||||
|
||||
|
||||
def squeeze(array, axis=None):
|
||||
"""
|
||||
Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||
arrays.
|
||||
"""
|
||||
if is_numpy_array(array):
|
||||
return np.squeeze(array, axis=axis)
|
||||
elif is_torch_tensor(array):
|
||||
return array.squeeze() if axis is None else array.squeeze(dim=axis)
|
||||
elif is_tf_tensor(array):
|
||||
return tf.squeeze(array, axis=axis)
|
||||
elif is_jax_tensor(array):
|
||||
return jnp.squeeze(array, axis=axis)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for squeeze: {type(array)}.")
|
||||
|
||||
|
||||
def expand_dims(array, axis):
|
||||
"""
|
||||
Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||
arrays.
|
||||
"""
|
||||
if is_numpy_array(array):
|
||||
return np.expand_dims(array, axis)
|
||||
elif is_torch_tensor(array):
|
||||
return array.unsqueeze(dim=axis)
|
||||
elif is_tf_tensor(array):
|
||||
return tf.expand_dims(array, axis=axis)
|
||||
elif is_jax_tensor(array):
|
||||
return jnp.expand_dims(array, axis=axis)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
|
||||
|
||||
@@ -15,7 +15,29 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.utils import flatten_dict
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_flax, require_tf, require_torch
|
||||
from transformers.utils import (
|
||||
expand_dims,
|
||||
flatten_dict,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
reshape,
|
||||
squeeze,
|
||||
transpose,
|
||||
)
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class GenericTester(unittest.TestCase):
|
||||
@@ -43,3 +65,136 @@ class GenericTester(unittest.TestCase):
|
||||
}
|
||||
|
||||
self.assertEqual(flatten_dict(input_dict), expected_dict)
|
||||
|
||||
def test_transpose_numpy(self):
|
||||
x = np.random.randn(3, 4)
|
||||
self.assertTrue(np.allclose(transpose(x), x.transpose()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), x.transpose((1, 2, 0))))
|
||||
|
||||
@require_torch
|
||||
def test_transpose_torch(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
|
||||
|
||||
@require_tf
|
||||
def test_transpose_tf(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = tf.constant(x)
|
||||
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = tf.constant(x)
|
||||
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
|
||||
|
||||
@require_flax
|
||||
def test_transpose_flax(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(transpose(x), np.asarray(transpose(t))))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), np.asarray(transpose(t, axes=(1, 2, 0)))))
|
||||
|
||||
def test_reshape_numpy(self):
|
||||
x = np.random.randn(3, 4)
|
||||
self.assertTrue(np.allclose(reshape(x, (4, 3)), np.reshape(x, (4, 3))))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
self.assertTrue(np.allclose(reshape(x, (12, 5)), np.reshape(x, (12, 5))))
|
||||
|
||||
@require_torch
|
||||
def test_reshape_torch(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
|
||||
|
||||
@require_tf
|
||||
def test_reshape_tf(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = tf.constant(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = tf.constant(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
|
||||
|
||||
@require_flax
|
||||
def test_reshape_flax(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (4, 3)), np.asarray(reshape(t, (4, 3)))))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (12, 5)), np.asarray(reshape(t, (12, 5)))))
|
||||
|
||||
def test_squeeze_numpy(self):
|
||||
x = np.random.randn(1, 3, 4)
|
||||
self.assertTrue(np.allclose(squeeze(x), np.squeeze(x)))
|
||||
|
||||
x = np.random.randn(1, 4, 1, 5)
|
||||
self.assertTrue(np.allclose(squeeze(x, axis=2), np.squeeze(x, axis=2)))
|
||||
|
||||
@require_torch
|
||||
def test_squeeze_torch(self):
|
||||
x = np.random.randn(1, 3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
|
||||
|
||||
x = np.random.randn(1, 4, 1, 5)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
|
||||
|
||||
@require_tf
|
||||
def test_squeeze_tf(self):
|
||||
x = np.random.randn(1, 3, 4)
|
||||
t = tf.constant(x)
|
||||
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
|
||||
|
||||
x = np.random.randn(1, 4, 1, 5)
|
||||
t = tf.constant(x)
|
||||
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
|
||||
|
||||
@require_flax
|
||||
def test_squeeze_flax(self):
|
||||
x = np.random.randn(1, 3, 4)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(squeeze(x), np.asarray(squeeze(t))))
|
||||
|
||||
x = np.random.randn(1, 4, 1, 5)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(squeeze(x, axis=2), np.asarray(squeeze(t, axis=2))))
|
||||
|
||||
def test_expand_dims_numpy(self):
|
||||
x = np.random.randn(3, 4)
|
||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.expand_dims(x, axis=1)))
|
||||
|
||||
@require_torch
|
||||
def test_expand_dims_torch(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
|
||||
|
||||
@require_tf
|
||||
def test_expand_dims_tf(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = tf.constant(x)
|
||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
|
||||
|
||||
@require_flax
|
||||
def test_expand_dims_flax(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
|
||||
|
||||
Reference in New Issue
Block a user