[Whisper] Fix audio classification with weighted layer sum (#28563)
* fix * tests * fix test
This commit is contained in:
@@ -57,6 +57,8 @@ if is_flash_attn_2_available():
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "WhisperConfig"
|
_CONFIG_FOR_DOC = "WhisperConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
|
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
|
||||||
|
|
||||||
@@ -2957,6 +2959,11 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
|
|||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
|
if self.config.use_weighted_layer_sum:
|
||||||
|
output_hidden_states = True
|
||||||
|
elif output_hidden_states is None:
|
||||||
|
output_hidden_states = self.config.output_hidden_states
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
@@ -2969,7 +2976,8 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.config.use_weighted_layer_sum:
|
if self.config.use_weighted_layer_sum:
|
||||||
hidden_states = torch.stack(encoder_outputs, dim=1)
|
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
|
||||||
|
hidden_states = torch.stack(hidden_states, dim=1)
|
||||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2292,16 +2292,15 @@ class WhisperEncoderModelTester:
|
|||||||
def encoder_seq_length(self):
|
def encoder_seq_length(self):
|
||||||
return self.get_subsampled_output_lengths(self.seq_length)
|
return self.get_subsampled_output_lengths(self.seq_length)
|
||||||
|
|
||||||
def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
|
def create_and_check_model_forward(self, config, inputs_dict, use_weighted_layer_sum=False):
|
||||||
model = WhisperForAudioClassification(config=config).to(torch_device).eval()
|
config.use_weighted_layer_sum = use_weighted_layer_sum
|
||||||
|
model = WhisperForAudioClassification(config=config)
|
||||||
if freeze_encoder:
|
model.to(torch_device).eval()
|
||||||
model.freeze_encoder()
|
|
||||||
|
|
||||||
input_features = inputs_dict["input_features"]
|
input_features = inputs_dict["input_features"]
|
||||||
|
|
||||||
# first forward pass
|
with torch.no_grad():
|
||||||
last_hidden_state = model(input_features).logits
|
last_hidden_state = model(input_features).logits
|
||||||
|
|
||||||
self.parent.assertTrue(last_hidden_state.shape, (13, 2))
|
self.parent.assertTrue(last_hidden_state.shape, (13, 2))
|
||||||
|
|
||||||
@@ -2336,6 +2335,14 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
|||||||
expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
|
expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
|
||||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
|
|
||||||
|
def test_forward_pass(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_forward_pass_weighted_layer_sum(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model_forward(*config_and_inputs, use_weighted_layer_sum=True)
|
||||||
|
|
||||||
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
|
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
|
||||||
def test_cpu_offload(self):
|
def test_cpu_offload(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user