Add TF<>PT and Flax<>PT everywhere (#14047)
* up * up * up * up * up * up * up * add clip * fix clip PyTorch * fix clip PyTorch * up * up * up * up * up * up * up
This commit is contained in:
committed by
GitHub
parent
8560b55b5e
commit
0c3174c758
@@ -314,35 +314,13 @@ class FlaxAlbertLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class FlaxAlbertLayers(nn.Module):
|
|
||||||
config: AlbertConfig
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
||||||
layer_index: Optional[str] = None
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.albert_layers = FlaxAlbertLayer(self.config, name=self.layer_index, dtype=self.dtype)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
deterministic: bool = True,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
):
|
|
||||||
outputs = self.albert_layers(
|
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
|
||||||
)
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxAlbertLayerCollection(nn.Module):
|
class FlaxAlbertLayerCollection(nn.Module):
|
||||||
config: AlbertConfig
|
config: AlbertConfig
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.layers = [
|
self.layers = [
|
||||||
FlaxAlbertLayers(self.config, name="albert_layers", layer_index=str(i), dtype=self.dtype)
|
FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
|
||||||
for i in range(self.config.inner_group_num)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -385,7 +363,7 @@ class FlaxAlbertLayerCollections(nn.Module):
|
|||||||
layer_index: Optional[str] = None
|
layer_index: Optional[str] = None
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.albert_layer_groups = FlaxAlbertLayerCollection(self.config, name=self.layer_index, dtype=self.dtype)
|
self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -395,7 +373,7 @@ class FlaxAlbertLayerCollections(nn.Module):
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
):
|
):
|
||||||
outputs = self.albert_layer_groups(
|
outputs = self.albert_layers(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
@@ -405,19 +383,13 @@ class FlaxAlbertLayerCollections(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class FlaxAlbertEncoder(nn.Module):
|
class FlaxAlbertLayerGroups(nn.Module):
|
||||||
config: AlbertConfig
|
config: AlbertConfig
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
|
|
||||||
self.embedding_hidden_mapping_in = nn.Dense(
|
|
||||||
self.config.hidden_size,
|
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
self.layers = [
|
self.layers = [
|
||||||
FlaxAlbertLayerCollections(self.config, name="albert_layer_groups", layer_index=str(i), dtype=self.dtype)
|
FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
|
||||||
for i in range(self.config.num_hidden_groups)
|
for i in range(self.config.num_hidden_groups)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -430,7 +402,6 @@ class FlaxAlbertEncoder(nn.Module):
|
|||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
||||||
|
|
||||||
@@ -459,6 +430,37 @@ class FlaxAlbertEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAlbertEncoder(nn.Module):
|
||||||
|
config: AlbertConfig
|
||||||
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.embedding_hidden_mapping_in = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
|
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
||||||
|
return self.albert_layer_groups(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxAlbertOnlyMLMHead(nn.Module):
|
class FlaxAlbertOnlyMLMHead(nn.Module):
|
||||||
config: AlbertConfig
|
config: AlbertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|||||||
@@ -1222,7 +1222,10 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
|
|||||||
if inputs["attention_mask"] is not None:
|
if inputs["attention_mask"] is not None:
|
||||||
# compute real output lengths according to convolution formula
|
# compute real output lengths according to convolution formula
|
||||||
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
|
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
|
||||||
attention_mask = tf.sequence_mask(output_lengths, dtype=hidden_states.dtype)
|
|
||||||
|
attention_mask = tf.sequence_mask(
|
||||||
|
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
|
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class BigBirdModelTester:
|
|||||||
num_hidden_layers=2,
|
num_hidden_layers=2,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
intermediate_size=37,
|
intermediate_size=37,
|
||||||
hidden_act="gelu_fast",
|
hidden_act="gelu_new",
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=256,
|
max_position_embeddings=256,
|
||||||
|
|||||||
@@ -23,9 +23,17 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import transformers
|
||||||
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||||
from transformers.file_utils import is_torch_available, is_vision_available
|
from transformers.file_utils import is_torch_available, is_vision_available
|
||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
is_flax_available,
|
||||||
|
is_pt_flax_cross_test,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, random_attention_mask
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, random_attention_mask
|
||||||
@@ -45,6 +53,14 @@ if is_vision_available():
|
|||||||
from transformers import CLIPProcessor
|
from transformers import CLIPProcessor
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
|
convert_pytorch_state_dict_to_flax,
|
||||||
|
load_flax_weights_in_pytorch_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionModelTester:
|
class CLIPVisionModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -330,6 +346,13 @@ class CLIPTextModelTester:
|
|||||||
if self.use_input_mask:
|
if self.use_input_mask:
|
||||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
if input_mask is not None:
|
||||||
|
batch_size, seq_length = input_mask.shape
|
||||||
|
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
|
||||||
|
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||||
|
input_mask[batch_idx, :start_index] = 1
|
||||||
|
input_mask[batch_idx, start_index:] = 0
|
||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
|
|
||||||
return config, input_ids, input_mask
|
return config, input_ids, input_mask
|
||||||
@@ -558,6 +581,125 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
# overwrite from common since FlaxCLIPModel returns nested output
|
||||||
|
# which is not supported in the common test
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_equivalence_pt_to_flax(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
|
||||||
|
# load PyTorch class
|
||||||
|
pt_model = model_class(config).eval()
|
||||||
|
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||||
|
# So we disable `use_cache` here for PyTorch model.
|
||||||
|
pt_model.config.use_cache = False
|
||||||
|
|
||||||
|
fx_model_class_name = "Flax" + model_class.__name__
|
||||||
|
|
||||||
|
if not hasattr(transformers, fx_model_class_name):
|
||||||
|
return
|
||||||
|
|
||||||
|
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||||
|
|
||||||
|
# load Flax class
|
||||||
|
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||||
|
# make sure only flax inputs are forward that actually exist in function args
|
||||||
|
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
# remove function args that don't exist in Flax
|
||||||
|
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||||
|
|
||||||
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
|
fx_model.params = fx_state
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
# convert inputs to Flax
|
||||||
|
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_model.save_pretrained(tmpdirname)
|
||||||
|
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
|
||||||
|
self.assertEqual(
|
||||||
|
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||||
|
)
|
||||||
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||||
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
|
# overwrite from common since FlaxCLIPModel returns nested output
|
||||||
|
# which is not supported in the common test
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_equivalence_flax_to_pt(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
# load corresponding PyTorch class
|
||||||
|
pt_model = model_class(config).eval()
|
||||||
|
|
||||||
|
# So we disable `use_cache` here for PyTorch model.
|
||||||
|
pt_model.config.use_cache = False
|
||||||
|
|
||||||
|
fx_model_class_name = "Flax" + model_class.__name__
|
||||||
|
|
||||||
|
if not hasattr(transformers, fx_model_class_name):
|
||||||
|
# no flax model exists for this class
|
||||||
|
return
|
||||||
|
|
||||||
|
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||||
|
|
||||||
|
# load Flax class
|
||||||
|
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||||
|
# make sure only flax inputs are forward that actually exist in function args
|
||||||
|
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||||
|
|
||||||
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||||
|
|
||||||
|
# make sure weights are tied in PyTorch
|
||||||
|
pt_model.tie_weights()
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
# remove function args that don't exist in Flax
|
||||||
|
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
|
||||||
|
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
fx_model.save_pretrained(tmpdirname)
|
||||||
|
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||||
|
)
|
||||||
|
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -25,10 +25,13 @@ import unittest
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import transformers
|
||||||
from huggingface_hub import HfApi, Repository
|
from huggingface_hub import HfApi, Repository
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
|
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
|
||||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
ENDPOINT_STAGING,
|
ENDPOINT_STAGING,
|
||||||
@@ -36,6 +39,8 @@ from transformers.testing_utils import (
|
|||||||
USER,
|
USER,
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
|
is_pt_flax_cross_test,
|
||||||
|
is_pt_tf_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
@@ -45,7 +50,6 @@ from transformers.testing_utils import (
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@@ -70,6 +74,13 @@ if is_torch_available():
|
|||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
|
convert_pytorch_state_dict_to_flax,
|
||||||
|
load_flax_weights_in_pytorch_model,
|
||||||
|
)
|
||||||
|
|
||||||
if is_torch_fx_available():
|
if is_torch_fx_available():
|
||||||
from transformers.utils.fx import symbolic_trace
|
from transformers.utils.fx import symbolic_trace
|
||||||
|
|
||||||
@@ -1417,6 +1428,241 @@ class ModelTesterMixin:
|
|||||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_pt_tf_model_equivalence(self):
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||||
|
|
||||||
|
if not hasattr(transformers, tf_model_class_name):
|
||||||
|
# transformers does not have TF version yet
|
||||||
|
return
|
||||||
|
|
||||||
|
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||||
|
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
|
tf_model = tf_model_class(config)
|
||||||
|
pt_model = model_class(config)
|
||||||
|
|
||||||
|
# make sure only tf inputs are forward that actually exist in function args
|
||||||
|
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
|
||||||
|
|
||||||
|
# remove all head masks
|
||||||
|
tf_input_keys.discard("head_mask")
|
||||||
|
tf_input_keys.discard("cross_attn_head_mask")
|
||||||
|
tf_input_keys.discard("decoder_head_mask")
|
||||||
|
|
||||||
|
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}
|
||||||
|
|
||||||
|
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||||
|
pt_model.eval()
|
||||||
|
tf_inputs_dict = {}
|
||||||
|
for key, tensor in pt_inputs.items():
|
||||||
|
# skip key that does not exist in tf
|
||||||
|
if type(tensor) == bool:
|
||||||
|
tf_inputs_dict[key] = tensor
|
||||||
|
elif key == "input_values":
|
||||||
|
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||||
|
else:
|
||||||
|
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||||
|
|
||||||
|
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||||
|
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||||
|
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||||
|
|
||||||
|
# need to rename encoder-decoder "inputs" for PyTorch
|
||||||
|
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||||
|
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pto = pt_model(**pt_inputs)
|
||||||
|
tfo = tf_model(tf_inputs_dict, training=False)
|
||||||
|
|
||||||
|
tf_hidden_states = tfo[0].numpy()
|
||||||
|
pt_hidden_states = pto[0].numpy()
|
||||||
|
|
||||||
|
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
||||||
|
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
||||||
|
|
||||||
|
pt_hidden_states[tf_nans] = 0
|
||||||
|
tf_hidden_states[tf_nans] = 0
|
||||||
|
pt_hidden_states[pt_nans] = 0
|
||||||
|
tf_hidden_states[pt_nans] = 0
|
||||||
|
|
||||||
|
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
||||||
|
self.assertLessEqual(max_diff, 4e-2)
|
||||||
|
|
||||||
|
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||||
|
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||||
|
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||||
|
|
||||||
|
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||||
|
tf_model.save_weights(tf_checkpoint_path)
|
||||||
|
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||||
|
|
||||||
|
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||||
|
pt_model.eval()
|
||||||
|
tf_inputs_dict = {}
|
||||||
|
for key, tensor in pt_inputs.items():
|
||||||
|
# skip key that does not exist in tf
|
||||||
|
if type(tensor) == bool:
|
||||||
|
tensor = np.array(tensor, dtype=bool)
|
||||||
|
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
|
||||||
|
elif key == "input_values":
|
||||||
|
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||||
|
else:
|
||||||
|
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||||
|
|
||||||
|
# need to rename encoder-decoder "inputs" for PyTorch
|
||||||
|
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||||
|
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pto = pt_model(**pt_inputs)
|
||||||
|
|
||||||
|
tfo = tf_model(tf_inputs_dict)
|
||||||
|
tfo = tfo[0].numpy()
|
||||||
|
pto = pto[0].numpy()
|
||||||
|
tf_nans = np.copy(np.isnan(tfo))
|
||||||
|
pt_nans = np.copy(np.isnan(pto))
|
||||||
|
|
||||||
|
pto[tf_nans] = 0
|
||||||
|
tfo[tf_nans] = 0
|
||||||
|
pto[pt_nans] = 0
|
||||||
|
tfo[pt_nans] = 0
|
||||||
|
|
||||||
|
max_diff = np.amax(np.abs(tfo - pto))
|
||||||
|
self.assertLessEqual(max_diff, 4e-2)
|
||||||
|
|
||||||
|
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||||
|
diff = np.abs((a - b)).max()
|
||||||
|
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_equivalence_pt_to_flax(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
|
||||||
|
# load PyTorch class
|
||||||
|
pt_model = model_class(config).eval()
|
||||||
|
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||||
|
# So we disable `use_cache` here for PyTorch model.
|
||||||
|
pt_model.config.use_cache = False
|
||||||
|
|
||||||
|
fx_model_class_name = "Flax" + model_class.__name__
|
||||||
|
|
||||||
|
if not hasattr(transformers, fx_model_class_name):
|
||||||
|
return
|
||||||
|
|
||||||
|
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||||
|
|
||||||
|
# load Flax class
|
||||||
|
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||||
|
# make sure only flax inputs are forward that actually exist in function args
|
||||||
|
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
# remove function args that don't exist in Flax
|
||||||
|
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||||
|
|
||||||
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
|
fx_model.params = fx_state
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
# convert inputs to Flax
|
||||||
|
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_model.save_pretrained(tmpdirname)
|
||||||
|
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
|
||||||
|
self.assertEqual(
|
||||||
|
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||||
|
)
|
||||||
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_equivalence_flax_to_pt(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
# load corresponding PyTorch class
|
||||||
|
pt_model = model_class(config).eval()
|
||||||
|
|
||||||
|
# So we disable `use_cache` here for PyTorch model.
|
||||||
|
pt_model.config.use_cache = False
|
||||||
|
|
||||||
|
fx_model_class_name = "Flax" + model_class.__name__
|
||||||
|
|
||||||
|
if not hasattr(transformers, fx_model_class_name):
|
||||||
|
# no flax model exists for this class
|
||||||
|
return
|
||||||
|
|
||||||
|
fx_model_class = getattr(transformers, fx_model_class_name)
|
||||||
|
|
||||||
|
# load Flax class
|
||||||
|
fx_model = fx_model_class(config, dtype=jnp.float32)
|
||||||
|
# make sure only flax inputs are forward that actually exist in function args
|
||||||
|
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||||
|
|
||||||
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||||
|
|
||||||
|
# make sure weights are tied in PyTorch
|
||||||
|
pt_model.tie_weights()
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
# remove function args that don't exist in Flax
|
||||||
|
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
|
||||||
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
fx_model.save_pretrained(tmpdirname)
|
||||||
|
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||||
|
)
|
||||||
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ class LongformerModelTester:
|
|||||||
# because its local attention only attends to `self.attention_window + 1` locations
|
# because its local attention only attends to `self.attention_window + 1` locations
|
||||||
# (assuming no token with global attention, otherwise the last dimension of attentions
|
# (assuming no token with global attention, otherwise the last dimension of attentions
|
||||||
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
||||||
self.key_length = self.attention_window + 1
|
self.key_length = self.attention_window + 2
|
||||||
|
|
||||||
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
||||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
||||||
@@ -243,6 +243,8 @@ class LongformerModelTester:
|
|||||||
choice_labels,
|
choice_labels,
|
||||||
) = config_and_inputs
|
) = config_and_inputs
|
||||||
global_attention_mask = torch.zeros_like(input_ids)
|
global_attention_mask = torch.zeros_like(input_ids)
|
||||||
|
global_attention_mask[:, -1] = 1
|
||||||
|
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"token_type_ids": token_type_ids,
|
"token_type_ids": token_type_ids,
|
||||||
|
|||||||
@@ -15,13 +15,16 @@
|
|||||||
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import LxmertConfig, is_torch_available
|
import transformers
|
||||||
|
from transformers import LxmertConfig, is_tf_available, is_torch_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
@@ -40,6 +43,10 @@ if is_torch_available():
|
|||||||
from transformers.models.lxmert.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.lxmert.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
class LxmertModelTester:
|
class LxmertModelTester:
|
||||||
"""You can also import this e.g from .test_modeling_bart import BartModelTester"""
|
"""You can also import this e.g from .test_modeling_bart import BartModelTester"""
|
||||||
|
|
||||||
@@ -496,7 +503,7 @@ class LxmertModelTester:
|
|||||||
result_pretrain_more.question_answering_score.shape, (self.batch_size, num_large_labels)
|
result_pretrain_more.question_answering_score.shape, (self.batch_size, num_large_labels)
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self, return_obj_labels=False):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
config,
|
config,
|
||||||
@@ -520,6 +527,9 @@ class LxmertModelTester:
|
|||||||
"attention_mask": input_mask,
|
"attention_mask": input_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if return_obj_labels:
|
||||||
|
inputs_dict["obj_labels"] = obj_labels
|
||||||
|
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -732,6 +742,128 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertIsNotNone(hidden_states_vision.grad)
|
self.assertIsNotNone(hidden_states_vision.grad)
|
||||||
self.assertIsNotNone(attentions_vision.grad)
|
self.assertIsNotNone(attentions_vision.grad)
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_pt_tf_model_equivalence(self):
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
|
||||||
|
return_obj_labels="PreTraining" in model_class.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||||
|
|
||||||
|
if not hasattr(transformers, tf_model_class_name):
|
||||||
|
# transformers does not have TF version yet
|
||||||
|
return
|
||||||
|
|
||||||
|
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||||
|
|
||||||
|
config.output_hidden_states = True
|
||||||
|
config.task_obj_predict = False
|
||||||
|
|
||||||
|
pt_model = model_class(config)
|
||||||
|
tf_model = tf_model_class(config)
|
||||||
|
|
||||||
|
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||||
|
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
def recursive_numpy_convert(iterable):
|
||||||
|
return_dict = {}
|
||||||
|
for key, value in iterable.items():
|
||||||
|
if type(value) == bool:
|
||||||
|
return_dict[key] = value
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return_dict[key] = recursive_numpy_convert(value)
|
||||||
|
else:
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
return_dict[key] = (
|
||||||
|
tf.convert_to_tensor(iter_value.numpy(), dtype=tf.int32) for iter_value in value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return_dict[key] = tf.convert_to_tensor(value.numpy(), dtype=tf.int32)
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
|
||||||
|
|
||||||
|
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||||
|
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||||
|
|
||||||
|
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||||
|
pt_model.eval()
|
||||||
|
|
||||||
|
# Delete obj labels as we want to compute the hidden states and not the loss
|
||||||
|
|
||||||
|
if "obj_labels" in inputs_dict:
|
||||||
|
del inputs_dict["obj_labels"]
|
||||||
|
|
||||||
|
def torch_type(key):
|
||||||
|
if key in ("visual_feats", "visual_pos"):
|
||||||
|
return torch.float32
|
||||||
|
else:
|
||||||
|
return torch.long
|
||||||
|
|
||||||
|
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pto = pt_model(**pt_inputs)
|
||||||
|
tfo = tf_model(tf_inputs_dict, training=False)
|
||||||
|
tf_hidden_states = tfo[0].numpy()
|
||||||
|
pt_hidden_states = pto[0].numpy()
|
||||||
|
|
||||||
|
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
||||||
|
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
||||||
|
|
||||||
|
pt_hidden_states[tf_nans] = 0
|
||||||
|
tf_hidden_states[tf_nans] = 0
|
||||||
|
pt_hidden_states[pt_nans] = 0
|
||||||
|
tf_hidden_states[pt_nans] = 0
|
||||||
|
|
||||||
|
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
||||||
|
# Debug info (remove when fixed)
|
||||||
|
if max_diff >= 2e-2:
|
||||||
|
print("===")
|
||||||
|
print(model_class)
|
||||||
|
print(config)
|
||||||
|
print(inputs_dict)
|
||||||
|
print(pt_inputs)
|
||||||
|
self.assertLessEqual(max_diff, 6e-2)
|
||||||
|
|
||||||
|
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||||
|
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||||
|
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||||
|
|
||||||
|
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||||
|
tf_model.save_weights(tf_checkpoint_path)
|
||||||
|
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||||
|
|
||||||
|
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||||
|
pt_model.eval()
|
||||||
|
|
||||||
|
for key, value in pt_inputs.items():
|
||||||
|
if key in ("visual_feats", "visual_pos"):
|
||||||
|
pt_inputs[key] = value.to(torch.float32)
|
||||||
|
else:
|
||||||
|
pt_inputs[key] = value.to(torch.long)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pto = pt_model(**pt_inputs)
|
||||||
|
|
||||||
|
tfo = tf_model(tf_inputs_dict)
|
||||||
|
tfo = tfo[0].numpy()
|
||||||
|
pto = pto[0].numpy()
|
||||||
|
tf_nans = np.copy(np.isnan(tfo))
|
||||||
|
pt_nans = np.copy(np.isnan(pto))
|
||||||
|
|
||||||
|
pto[tf_nans] = 0
|
||||||
|
tfo[tf_nans] = 0
|
||||||
|
pto[pt_nans] = 0
|
||||||
|
tfo[pt_nans] = 0
|
||||||
|
|
||||||
|
max_diff = np.amax(np.abs(tfo - pto))
|
||||||
|
self.assertLessEqual(max_diff, 6e-2)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LxmertModelIntegrationTest(unittest.TestCase):
|
class LxmertModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -431,7 +431,6 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
@is_pt_tf_cross_test
|
@is_pt_tf_cross_test
|
||||||
def test_pt_tf_model_equivalence(self):
|
def test_pt_tf_model_equivalence(self):
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -22,7 +22,14 @@ import pytest
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import Wav2Vec2Config, is_torch_available
|
from transformers import Wav2Vec2Config, is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
is_pt_flax_cross_test,
|
||||||
|
require_datasets,
|
||||||
|
require_soundfile,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
@@ -131,6 +138,7 @@ class Wav2Vec2ModelTester:
|
|||||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
intermediate_size=self.intermediate_size,
|
intermediate_size=self.intermediate_size,
|
||||||
layer_norm_eps=self.layer_norm_eps,
|
layer_norm_eps=self.layer_norm_eps,
|
||||||
|
do_stable_layer_norm=self.do_stable_layer_norm,
|
||||||
hidden_act=self.hidden_act,
|
hidden_act=self.hidden_act,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
@@ -357,6 +365,16 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
# non-robust architecture does not exist in Flax
|
||||||
|
def test_equivalence_flax_to_pt(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
# non-robust architecture does not exist in Flax
|
||||||
|
def test_equivalence_pt_to_flax(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_retain_grad_hidden_states_attentions(self):
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
|
|||||||
Reference in New Issue
Block a user