Add TF<>PT and Flax<>PT everywhere (#14047)

* up

* up

* up

* up

* up

* up

* up

* add clip

* fix clip PyTorch

* fix clip PyTorch

* up

* up

* up

* up

* up

* up

* up
This commit is contained in:
Patrick von Platen
2021-10-25 23:55:08 +02:00
committed by GitHub
parent 8560b55b5e
commit 0c3174c758
9 changed files with 589 additions and 45 deletions

View File

@@ -22,7 +22,14 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_datasets,
require_soundfile,
require_torch,
slow,
torch_device,
)
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init
@@ -131,6 +138,7 @@ class Wav2Vec2ModelTester:
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
do_stable_layer_norm=self.do_stable_layer_norm,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
@@ -357,6 +365,16 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True