[Whiper] add get_input_embeddings to WhisperForAudioClassification (#22133)
* add `get_input_embeddings` to `WhisperForAudioClassification` * add common tests * fix another common test * Update tests/models/whisper/test_modeling_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -357,9 +357,24 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
return config, input_ids, None, max_length
|
||||
|
||||
# not implemented currently
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
decoder_input_ids = inputs.pop("decoder_input_ids", None)
|
||||
inputs.pop("decoder_attention_mask", None)
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# training is not supported yet
|
||||
def test_training(self):
|
||||
@@ -1566,9 +1581,16 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
||||
|
||||
self.assertTrue((outputs_embeds == outputs).all())
|
||||
|
||||
# WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
|
||||
# Needs to override as the encoder input embedding is a Conv1d
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Conv1d))
|
||||
model.set_input_embeddings(torch.nn.Conv1d(10, 10, 3))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, torch.nn.Conv1d))
|
||||
|
||||
# WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
|
||||
def test_resize_tokens_embeddings(self):
|
||||
|
||||
Reference in New Issue
Block a user