[Whisper] Make tests faster (#24105)
This commit is contained in:
@@ -95,7 +95,7 @@ class WhisperModelTester:
|
|||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
seq_length=1500,
|
seq_length=60,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_labels=False,
|
use_labels=False,
|
||||||
vocab_size=200,
|
vocab_size=200,
|
||||||
@@ -107,7 +107,7 @@ class WhisperModelTester:
|
|||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=20,
|
max_position_embeddings=20,
|
||||||
max_source_positions=750,
|
max_source_positions=30,
|
||||||
max_target_positions=40,
|
max_target_positions=40,
|
||||||
bos_token_id=98,
|
bos_token_id=98,
|
||||||
eos_token_id=98,
|
eos_token_id=98,
|
||||||
@@ -1538,7 +1538,7 @@ class WhisperEncoderModelTester:
|
|||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
seq_length=3000,
|
seq_length=60,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
hidden_size=16,
|
hidden_size=16,
|
||||||
@@ -1549,7 +1549,7 @@ class WhisperEncoderModelTester:
|
|||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=20,
|
max_position_embeddings=20,
|
||||||
max_source_positions=1500,
|
max_source_positions=30,
|
||||||
num_mel_bins=80,
|
num_mel_bins=80,
|
||||||
num_conv_layers=1,
|
num_conv_layers=1,
|
||||||
suppress_tokens=None,
|
suppress_tokens=None,
|
||||||
@@ -1731,3 +1731,156 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
|||||||
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
|
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
    """Check that a Flax model loaded with PyTorch weights matches the PyTorch model.

    For every class in ``self.all_model_classes`` that has a ``Flax``-prefixed
    counterpart in ``transformers``, the PyTorch weights are transferred to the
    Flax model twice -- once in memory via ``convert_pytorch_state_dict_to_flax``
    and once through ``save_pretrained`` / ``from_pretrained(..., from_pt=True)``
    -- and the outputs of both frameworks are compared each time.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    # Same shape as the test inputs, but with the leading dimension replaced by 1.
    init_shape = (1,) + inputs_dict["input_features"].shape[1:]

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            fx_model_class_name = "Flax" + model_class.__name__

            if not hasattr(transformers, fx_model_class_name):
                # No Flax model exists for this class: skip it, but keep
                # checking the remaining classes instead of ending the test.
                continue

            # Output all for aggressive testing
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            fx_model_class = getattr(transformers, fx_model_class_name)

            # load PyTorch class
            pt_model = model_class(config).eval()
            # Flax models don't use the `use_cache` option and cache is not returned as a default.
            # So we disable `use_cache` here for PyTorch model.
            pt_model.config.use_cache = False

            # load Flax class
            fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

            # make sure only inputs that actually exist in the Flax call signature are forwarded
            fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

            # prepare inputs
            pt_inputs = self._prepare_for_class(inputs_dict, model_class)

            # remove function args that don't exist in Flax
            pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

            # send pytorch inputs to the correct device
            pt_inputs = {
                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
            }

            # convert inputs to Flax; tensors are moved back to CPU first because
            # NumPy cannot read a tensor that lives on an accelerator device
            fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

            # in-memory conversion of the PyTorch weights into Flax parameters
            fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
            fx_model.params = fx_state

            # send pytorch model to the correct device
            pt_model.to(torch_device)

            with torch.no_grad():
                pt_outputs = pt_model(**pt_inputs)
            fx_outputs = fx_model(**fx_inputs)

            # both frameworks must return the same set of non-None outputs
            fx_keys = tuple(k for k, v in fx_outputs.items() if v is not None)
            pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

            # second pass: round-trip the PyTorch weights through disk
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)

            fx_outputs_loaded = fx_model_loaded(**fx_inputs)

            fx_keys = tuple(k for k, v in fx_outputs_loaded.items() if v is not None)
            pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
    """Check that a PyTorch model loaded with Flax weights matches the Flax model.

    For every class in ``self.all_model_classes`` that has a ``Flax``-prefixed
    counterpart in ``transformers``, the Flax parameters are transferred to the
    PyTorch model twice -- once in memory via ``load_flax_weights_in_pytorch_model``
    and once through ``save_pretrained`` / ``from_pretrained(..., from_flax=True)``
    -- and the outputs of both frameworks are compared each time.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    # Same shape as the test inputs, but with the leading dimension replaced by 1.
    init_shape = (1,) + inputs_dict["input_features"].shape[1:]

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            fx_model_class_name = "Flax" + model_class.__name__

            if not hasattr(transformers, fx_model_class_name):
                # No Flax model exists for this class: skip it, but keep
                # checking the remaining classes instead of ending the test.
                continue

            # Output all for aggressive testing
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            fx_model_class = getattr(transformers, fx_model_class_name)

            # load PyTorch class
            pt_model = model_class(config).eval()
            # Flax models don't use the `use_cache` option and cache is not returned as a default.
            # So we disable `use_cache` here for PyTorch model.
            pt_model.config.use_cache = False

            # load Flax class
            fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

            # make sure only inputs that actually exist in the Flax call signature are forwarded
            fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

            # prepare inputs
            pt_inputs = self._prepare_for_class(inputs_dict, model_class)

            # remove function args that don't exist in Flax
            pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

            # send pytorch inputs to the correct device
            pt_inputs = {
                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
            }

            # convert inputs to Flax; tensors are moved back to CPU first because
            # NumPy cannot read a tensor that lives on an accelerator device
            fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

            # in-memory conversion of the Flax parameters into the PyTorch model
            pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

            # make sure weights are tied in PyTorch
            pt_model.tie_weights()

            # send pytorch model to the correct device
            pt_model.to(torch_device)

            with torch.no_grad():
                pt_outputs = pt_model(**pt_inputs)
            fx_outputs = fx_model(**fx_inputs)

            # both frameworks must return the same set of non-None outputs
            fx_keys = tuple(k for k, v in fx_outputs.items() if v is not None)
            pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

            # second pass: round-trip the Flax weights through disk
            with tempfile.TemporaryDirectory() as tmpdirname:
                fx_model.save_pretrained(tmpdirname)
                pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

            # send pytorch model to the correct device
            pt_model_loaded.to(torch_device)
            pt_model_loaded.eval()

            with torch.no_grad():
                pt_outputs_loaded = pt_model_loaded(**pt_inputs)

            fx_keys = tuple(k for k, v in fx_outputs.items() if v is not None)
            pt_keys = tuple(k for k, v in pt_outputs_loaded.items() if v is not None)

            self.assertEqual(fx_keys, pt_keys)
            self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||||
|
|||||||
Reference in New Issue
Block a user