Add TF<>PT and Flax<>PT everywhere (#14047)
* up * up * up * up * up * up * up * add clip * fix clip PyTorch * fix clip PyTorch * up * up * up * up * up * up * up
This commit is contained in:
committed by
GitHub
parent
8560b55b5e
commit
0c3174c758
@@ -22,7 +22,14 @@ import pytest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from transformers import Wav2Vec2Config, is_torch_available
|
||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_datasets,
|
||||
require_soundfile,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
@@ -131,6 +138,7 @@ class Wav2Vec2ModelTester:
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
intermediate_size=self.intermediate_size,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
do_stable_layer_norm=self.do_stable_layer_norm,
|
||||
hidden_act=self.hidden_act,
|
||||
initializer_range=self.initializer_range,
|
||||
vocab_size=self.vocab_size,
|
||||
@@ -357,6 +365,16 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
# non-robust architecture does not exist in Flax
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
pass
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
# non-robust architecture does not exist in Flax
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
pass
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
|
||||
Reference in New Issue
Block a user