Output None as attention when layer is skipped (#30597)
* Output `None` as attention when layer is skipped
* Add test for output_attentions
This commit is contained in:
@@ -727,7 +727,7 @@ class WavLMEncoder(nn.Module):
|
|||||||
hidden_states, position_bias = layer_outputs[:2]
|
hidden_states, position_bias = layer_outputs[:2]
|
||||||
|
|
||||||
if skip_the_layer:
|
if skip_the_layer:
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None, None)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[2],)
|
all_self_attentions = all_self_attentions + (layer_outputs[2],)
|
||||||
@@ -810,7 +810,7 @@ class WavLMEncoderStableLayerNorm(nn.Module):
|
|||||||
hidden_states, position_bias = layer_outputs[:2]
|
hidden_states, position_bias = layer_outputs[:2]
|
||||||
|
|
||||||
if skip_the_layer:
|
if skip_the_layer:
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None, None)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[2],)
|
all_self_attentions = all_self_attentions + (layer_outputs[2],)
|
||||||
|
|||||||
@@ -288,6 +288,15 @@ class WavLMModelTester:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def check_output_attentions(self, config, input_values, attention_mask):
|
||||||
|
model = WavLMModel(config=config)
|
||||||
|
model.config.layerdrop = 1.0
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
outputs = model(input_values, attention_mask=attention_mask, output_attentions=True)
|
||||||
|
self.parent.assertTrue(len(outputs.attentions) > 0)
|
||||||
|
|
||||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
model = WavLMForCTC(config)
|
model = WavLMForCTC(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -354,6 +363,10 @@ class WavLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_output_attentions(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_output_attentions(*config_and_inputs)
|
||||||
|
|
||||||
def test_labels_out_of_vocab(self):
|
def test_labels_out_of_vocab(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user