Wav2Vec2 models must either throw or deal with add_apater (#15409)
* Wav2Vec2 models must either throw or deal with add_apater Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Add pre-add_adapter backwards compatibility * Add pre-add_adapter backwards compatibility * Fix issue in tests/test_modeling_wav2vec2.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -202,6 +202,17 @@ class Wav2Vec2ModelTester:
|
||||
result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_model_with_adapter_for_ctc(self, config, input_values, attention_mask):
|
||||
config.add_adapter = True
|
||||
config.output_hidden_size = 2 * config.hidden_size
|
||||
model = Wav2Vec2ForCTC(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_values, attention_mask=attention_mask)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
|
||||
)
|
||||
|
||||
def create_and_check_model_with_adapter_proj_dim(self, config, input_values, attention_mask):
|
||||
config.add_adapter = True
|
||||
config.output_hidden_size = 8
|
||||
@@ -414,6 +425,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
|
||||
|
||||
def test_model_with_adapter_for_ctc(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter_for_ctc(*config_and_inputs)
|
||||
|
||||
def test_model_with_adapter_proj_dim(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user