Refactor conversion function (#19799)
* Refactor conversion function * Remove dupe line * Fixes * Fixes * Use the right variable... * Fix last test
This commit is contained in:
@@ -21,7 +21,8 @@ import re
|
|||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from .utils import ExplicitEnum, logging
|
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze
|
||||||
|
from .utils import transpose as transpose_func
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -66,10 +67,12 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
|
|||||||
if len(tf_name) > 1:
|
if len(tf_name) > 1:
|
||||||
tf_name = tf_name[1:] # Remove level zero
|
tf_name = tf_name[1:] # Remove level zero
|
||||||
|
|
||||||
|
tf_weight_shape = list(tf_weight_shape)
|
||||||
|
|
||||||
# When should we transpose the weights
|
# When should we transpose the weights
|
||||||
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
|
if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
|
||||||
transpose = TransposeType.CONV2D
|
transpose = TransposeType.CONV2D
|
||||||
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 3:
|
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
|
||||||
transpose = TransposeType.CONV1D
|
transpose = TransposeType.CONV1D
|
||||||
elif bool(
|
elif bool(
|
||||||
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
|
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
|
||||||
@@ -98,6 +101,43 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
|
|||||||
return tf_name, transpose
|
return tf_name, transpose
|
||||||
|
|
||||||
|
|
||||||
|
def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
|
||||||
|
"""
|
||||||
|
Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a
|
||||||
|
framework agnostic way.
|
||||||
|
"""
|
||||||
|
if transpose is TransposeType.CONV2D:
|
||||||
|
# Conv2D weight:
|
||||||
|
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||||
|
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||||
|
axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
|
||||||
|
weight = transpose_func(weight, axes=axes)
|
||||||
|
elif transpose is TransposeType.CONV1D:
|
||||||
|
# Conv1D weight:
|
||||||
|
# PT: (num_out_channel, num_in_channel, kernel)
|
||||||
|
# -> TF: (kernel, num_in_channel, num_out_channel)
|
||||||
|
weight = transpose_func(weight, axes=(2, 1, 0))
|
||||||
|
elif transpose is TransposeType.SIMPLE:
|
||||||
|
weight = transpose_func(weight)
|
||||||
|
|
||||||
|
if match_shape is None:
|
||||||
|
return weight
|
||||||
|
|
||||||
|
if len(match_shape) < len(weight.shape):
|
||||||
|
weight = squeeze(weight)
|
||||||
|
elif len(match_shape) > len(weight.shape):
|
||||||
|
weight = expand_dims(weight, axis=0)
|
||||||
|
|
||||||
|
if list(match_shape) != list(weight.shape):
|
||||||
|
try:
|
||||||
|
weight = reshape(weight, match_shape)
|
||||||
|
except AssertionError as e:
|
||||||
|
e.args += (match_shape, match_shape)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# PyTorch => TF 2.0 #
|
# PyTorch => TF 2.0 #
|
||||||
#####################
|
#####################
|
||||||
@@ -155,7 +195,6 @@ def load_pytorch_weights_in_tf2_model(
|
|||||||
try:
|
try:
|
||||||
import tensorflow as tf # noqa: F401
|
import tensorflow as tf # noqa: F401
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
from tensorflow.python.keras import backend as K
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
@@ -163,6 +202,22 @@ def load_pytorch_weights_in_tf2_model(
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||||
|
return load_pytorch_state_dict_in_tf2_model(
|
||||||
|
tf_model,
|
||||||
|
pt_state_dict,
|
||||||
|
tf_inputs=tf_inputs,
|
||||||
|
allow_missing_keys=allow_missing_keys,
|
||||||
|
output_loading_info=output_loading_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_pytorch_state_dict_in_tf2_model(
|
||||||
|
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
|
||||||
|
):
|
||||||
|
"""Load a pytorch state_dict in a TF 2.0 model."""
|
||||||
|
from tensorflow.python.keras import backend as K
|
||||||
|
|
||||||
if tf_inputs is None:
|
if tf_inputs is None:
|
||||||
tf_inputs = tf_model.dummy_inputs
|
tf_inputs = tf_model.dummy_inputs
|
||||||
|
|
||||||
@@ -216,41 +271,9 @@ def load_pytorch_weights_in_tf2_model(
|
|||||||
continue
|
continue
|
||||||
raise AttributeError(f"{name} not found in PyTorch model")
|
raise AttributeError(f"{name} not found in PyTorch model")
|
||||||
|
|
||||||
array = pt_state_dict[name].numpy()
|
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
|
||||||
|
|
||||||
if transpose is TransposeType.CONV2D:
|
|
||||||
# Conv2D weight:
|
|
||||||
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
|
||||||
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
|
||||||
array = numpy.transpose(array, axes=(2, 3, 1, 0))
|
|
||||||
elif transpose is TransposeType.CONV1D:
|
|
||||||
# Conv1D weight:
|
|
||||||
# PT: (num_out_channel, num_in_channel, kernel)
|
|
||||||
# -> TF: (kernel, num_in_channel, num_out_channel)
|
|
||||||
array = numpy.transpose(array, axes=(2, 1, 0))
|
|
||||||
elif transpose is TransposeType.SIMPLE:
|
|
||||||
array = numpy.transpose(array)
|
|
||||||
|
|
||||||
if len(symbolic_weight.shape) < len(array.shape):
|
|
||||||
array = numpy.squeeze(array)
|
|
||||||
elif len(symbolic_weight.shape) > len(array.shape):
|
|
||||||
array = numpy.expand_dims(array, axis=0)
|
|
||||||
|
|
||||||
if list(symbolic_weight.shape) != list(array.shape):
|
|
||||||
try:
|
|
||||||
array = numpy.reshape(array, symbolic_weight.shape)
|
|
||||||
except AssertionError as e:
|
|
||||||
e.args += (symbolic_weight.shape, array.shape)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert list(symbolic_weight.shape) == list(array.shape)
|
|
||||||
except AssertionError as e:
|
|
||||||
e.args += (symbolic_weight.shape, array.shape)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
tf_loaded_numel += array.size
|
tf_loaded_numel += array.size
|
||||||
# logger.warning(f"Initialize TF weight {symbolic_weight.name}")
|
|
||||||
|
|
||||||
weight_value_tuples.append((symbolic_weight, array))
|
weight_value_tuples.append((symbolic_weight, array))
|
||||||
all_pytorch_weights.discard(name)
|
all_pytorch_weights.discard(name)
|
||||||
@@ -370,6 +393,15 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
|
||||||
|
return load_tf2_state_dict_in_pytorch_model(
|
||||||
|
pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
|
||||||
|
import torch
|
||||||
|
|
||||||
new_pt_params_dict = {}
|
new_pt_params_dict = {}
|
||||||
current_pt_params_dict = dict(pt_model.named_parameters())
|
current_pt_params_dict = dict(pt_model.named_parameters())
|
||||||
|
|
||||||
@@ -381,11 +413,11 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
|||||||
|
|
||||||
# Build a map from potential PyTorch weight names to TF 2.0 Variables
|
# Build a map from potential PyTorch weight names to TF 2.0 Variables
|
||||||
tf_weights_map = {}
|
tf_weights_map = {}
|
||||||
for tf_weight in tf_weights:
|
for name, tf_weight in tf_state_dict.items():
|
||||||
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
||||||
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
|
name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
|
||||||
)
|
)
|
||||||
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
|
tf_weights_map[pt_name] = (tf_weight, transpose)
|
||||||
|
|
||||||
all_tf_weights = set(list(tf_weights_map.keys()))
|
all_tf_weights = set(list(tf_weights_map.keys()))
|
||||||
loaded_pt_weights_data_ptr = {}
|
loaded_pt_weights_data_ptr = {}
|
||||||
@@ -406,43 +438,18 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
|||||||
|
|
||||||
array, transpose = tf_weights_map[pt_weight_name]
|
array, transpose = tf_weights_map[pt_weight_name]
|
||||||
|
|
||||||
if transpose is TransposeType.CONV2D:
|
array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
|
||||||
# Conv2D weight:
|
|
||||||
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
|
||||||
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
|
||||||
array = numpy.transpose(array, axes=(3, 2, 0, 1))
|
|
||||||
elif transpose is TransposeType.CONV1D:
|
|
||||||
# Conv1D weight:
|
|
||||||
# TF: (kernel, num_in_channel, num_out_channel)
|
|
||||||
# -> PT: (num_out_channel, num_in_channel, kernel)
|
|
||||||
array = numpy.transpose(array, axes=(2, 1, 0))
|
|
||||||
elif transpose is TransposeType.SIMPLE:
|
|
||||||
array = numpy.transpose(array)
|
|
||||||
|
|
||||||
if len(pt_weight.shape) < len(array.shape):
|
|
||||||
array = numpy.squeeze(array)
|
|
||||||
elif len(pt_weight.shape) > len(array.shape):
|
|
||||||
array = numpy.expand_dims(array, axis=0)
|
|
||||||
|
|
||||||
if list(pt_weight.shape) != list(array.shape):
|
|
||||||
try:
|
|
||||||
array = numpy.reshape(array, pt_weight.shape)
|
|
||||||
except AssertionError as e:
|
|
||||||
e.args += (pt_weight.shape, array.shape)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert list(pt_weight.shape) == list(array.shape)
|
|
||||||
except AssertionError as e:
|
|
||||||
e.args += (pt_weight.shape, array.shape)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# logger.warning(f"Initialize PyTorch weight {pt_weight_name}")
|
|
||||||
# Make sure we have a proper numpy array
|
|
||||||
if numpy.isscalar(array):
|
if numpy.isscalar(array):
|
||||||
array = numpy.array(array)
|
array = numpy.array(array)
|
||||||
new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
|
if not is_torch_tensor(array) and not is_numpy_array(array):
|
||||||
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
|
array = array.numpy()
|
||||||
|
if is_numpy_array(array):
|
||||||
|
# Convert to torch tensor
|
||||||
|
array = torch.from_numpy(array)
|
||||||
|
|
||||||
|
new_pt_params_dict[pt_weight_name] = array
|
||||||
|
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
|
||||||
all_tf_weights.discard(pt_weight_name)
|
all_tf_weights.discard(pt_weight_name)
|
||||||
|
|
||||||
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
|
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from .generic import (
|
|||||||
PaddingStrategy,
|
PaddingStrategy,
|
||||||
TensorType,
|
TensorType,
|
||||||
cached_property,
|
cached_property,
|
||||||
|
expand_dims,
|
||||||
find_labels,
|
find_labels,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
is_jax_tensor,
|
is_jax_tensor,
|
||||||
@@ -46,8 +47,11 @@ from .generic import (
|
|||||||
is_tf_tensor,
|
is_tf_tensor,
|
||||||
is_torch_device,
|
is_torch_device,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
|
reshape,
|
||||||
|
squeeze,
|
||||||
to_numpy,
|
to_numpy,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
|
transpose,
|
||||||
working_or_temp_dir,
|
working_or_temp_dir,
|
||||||
)
|
)
|
||||||
from .hub import (
|
from .hub import (
|
||||||
|
|||||||
@@ -29,6 +29,13 @@ import numpy as np
|
|||||||
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
|
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
class cached_property(property):
|
class cached_property(property):
|
||||||
"""
|
"""
|
||||||
Descriptor that mimics @property but caches output in member variable.
|
Descriptor that mimics @property but caches output in member variable.
|
||||||
@@ -370,3 +377,71 @@ def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
|
|||||||
yield tmp_dir
|
yield tmp_dir
|
||||||
else:
|
else:
|
||||||
yield working_dir
|
yield working_dir
|
||||||
|
|
||||||
|
|
||||||
|
def transpose(array, axes=None):
|
||||||
|
"""
|
||||||
|
Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||||
|
arrays.
|
||||||
|
"""
|
||||||
|
if is_numpy_array(array):
|
||||||
|
return np.transpose(array, axes=axes)
|
||||||
|
elif is_torch_tensor(array):
|
||||||
|
return array.T if axes is None else array.permute(*axes)
|
||||||
|
elif is_tf_tensor(array):
|
||||||
|
return tf.transpose(array, perm=axes)
|
||||||
|
elif is_jax_tensor(array):
|
||||||
|
return jnp.transpose(array, axes=axes)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Type not supported for transpose: {type(array)}.")
|
||||||
|
|
||||||
|
|
||||||
|
def reshape(array, newshape):
|
||||||
|
"""
|
||||||
|
Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||||
|
arrays.
|
||||||
|
"""
|
||||||
|
if is_numpy_array(array):
|
||||||
|
return np.reshape(array, newshape)
|
||||||
|
elif is_torch_tensor(array):
|
||||||
|
return array.reshape(*newshape)
|
||||||
|
elif is_tf_tensor(array):
|
||||||
|
return tf.reshape(array, newshape)
|
||||||
|
elif is_jax_tensor(array):
|
||||||
|
return jnp.reshape(array, newshape)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Type not supported for reshape: {type(array)}.")
|
||||||
|
|
||||||
|
|
||||||
|
def squeeze(array, axis=None):
|
||||||
|
"""
|
||||||
|
Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||||
|
arrays.
|
||||||
|
"""
|
||||||
|
if is_numpy_array(array):
|
||||||
|
return np.squeeze(array, axis=axis)
|
||||||
|
elif is_torch_tensor(array):
|
||||||
|
return array.squeeze() if axis is None else array.squeeze(dim=axis)
|
||||||
|
elif is_tf_tensor(array):
|
||||||
|
return tf.squeeze(array, axis=axis)
|
||||||
|
elif is_jax_tensor(array):
|
||||||
|
return jnp.squeeze(array, axis=axis)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Type not supported for squeeze: {type(array)}.")
|
||||||
|
|
||||||
|
|
||||||
|
def expand_dims(array, axis):
|
||||||
|
"""
|
||||||
|
Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
|
||||||
|
arrays.
|
||||||
|
"""
|
||||||
|
if is_numpy_array(array):
|
||||||
|
return np.expand_dims(array, axis)
|
||||||
|
elif is_torch_tensor(array):
|
||||||
|
return array.unsqueeze(dim=axis)
|
||||||
|
elif is_tf_tensor(array):
|
||||||
|
return tf.expand_dims(array, axis=axis)
|
||||||
|
elif is_jax_tensor(array):
|
||||||
|
return jnp.expand_dims(array, axis=axis)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
|
||||||
|
|||||||
@@ -15,7 +15,29 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.utils import flatten_dict
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.testing_utils import require_flax, require_tf, require_torch
|
||||||
|
from transformers.utils import (
|
||||||
|
expand_dims,
|
||||||
|
flatten_dict,
|
||||||
|
is_flax_available,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
reshape,
|
||||||
|
squeeze,
|
||||||
|
transpose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class GenericTester(unittest.TestCase):
|
class GenericTester(unittest.TestCase):
|
||||||
@@ -43,3 +65,136 @@ class GenericTester(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.assertEqual(flatten_dict(input_dict), expected_dict)
|
self.assertEqual(flatten_dict(input_dict), expected_dict)
|
||||||
|
|
||||||
|
def test_transpose_numpy(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
self.assertTrue(np.allclose(transpose(x), x.transpose()))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), x.transpose((1, 2, 0))))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_transpose_torch(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = torch.tensor(x)
|
||||||
|
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
t = torch.tensor(x)
|
||||||
|
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_transpose_tf(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = tf.constant(x)
|
||||||
|
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
t = tf.constant(x)
|
||||||
|
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
def test_transpose_flax(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = jnp.array(x)
|
||||||
|
self.assertTrue(np.allclose(transpose(x), np.asarray(transpose(t))))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
t = jnp.array(x)
|
||||||
|
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), np.asarray(transpose(t, axes=(1, 2, 0)))))
|
||||||
|
|
||||||
|
def test_reshape_numpy(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (4, 3)), np.reshape(x, (4, 3))))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (12, 5)), np.reshape(x, (12, 5))))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_reshape_torch(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = torch.tensor(x)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
t = torch.tensor(x)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_reshape_tf(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = tf.constant(x)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
t = tf.constant(x)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
def test_reshape_flax(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = jnp.array(x)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (4, 3)), np.asarray(reshape(t, (4, 3)))))
|
||||||
|
|
||||||
|
x = np.random.randn(3, 4, 5)
|
||||||
|
t = jnp.array(x)
|
||||||
|
self.assertTrue(np.allclose(reshape(x, (12, 5)), np.asarray(reshape(t, (12, 5)))))
|
||||||
|
|
||||||
|
def test_squeeze_numpy(self):
|
||||||
|
x = np.random.randn(1, 3, 4)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x), np.squeeze(x)))
|
||||||
|
|
||||||
|
x = np.random.randn(1, 4, 1, 5)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x, axis=2), np.squeeze(x, axis=2)))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_squeeze_torch(self):
|
||||||
|
x = np.random.randn(1, 3, 4)
|
||||||
|
t = torch.tensor(x)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
|
||||||
|
|
||||||
|
x = np.random.randn(1, 4, 1, 5)
|
||||||
|
t = torch.tensor(x)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_squeeze_tf(self):
|
||||||
|
x = np.random.randn(1, 3, 4)
|
||||||
|
t = tf.constant(x)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
|
||||||
|
|
||||||
|
x = np.random.randn(1, 4, 1, 5)
|
||||||
|
t = tf.constant(x)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
def test_squeeze_flax(self):
|
||||||
|
x = np.random.randn(1, 3, 4)
|
||||||
|
t = jnp.array(x)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x), np.asarray(squeeze(t))))
|
||||||
|
|
||||||
|
x = np.random.randn(1, 4, 1, 5)
|
||||||
|
t = jnp.array(x)
|
||||||
|
self.assertTrue(np.allclose(squeeze(x, axis=2), np.asarray(squeeze(t, axis=2))))
|
||||||
|
|
||||||
|
def test_expand_dims_numpy(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.expand_dims(x, axis=1)))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_expand_dims_torch(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = torch.tensor(x)
|
||||||
|
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_expand_dims_tf(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = tf.constant(x)
|
||||||
|
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
def test_expand_dims_flax(self):
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
t = jnp.array(x)
|
||||||
|
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
|
||||||
|
|||||||
Reference in New Issue
Block a user