[Speech] Refactor Examples (#14040)
* adapt_examples * up * up * up * up * add auto models * finish
This commit is contained in:
committed by
GitHub
parent
2024faf171
commit
d5ff69fce9
@@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import SEWForCTC, SEWModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers import (
|
||||
SEWForCTC,
|
||||
SEWForSequenceClassification,
|
||||
SEWModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||
|
||||
|
||||
@@ -219,6 +225,54 @@ class SEWModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_seq_classifier_loss(self, config, input_values, *args):
|
||||
model = SEWForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
|
||||
# make sure that dropout is disabled
|
||||
model.eval()
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0
|
||||
|
||||
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||
unmasked_loss = model(input_values, labels=labels).loss.item()
|
||||
|
||||
self.parent.assertTrue(isinstance(masked_loss, float))
|
||||
self.parent.assertTrue(isinstance(unmasked_loss, float))
|
||||
self.parent.assertTrue(masked_loss != unmasked_loss)
|
||||
|
||||
def check_seq_classifier_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = SEWForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# freeze everything but the classification head
|
||||
model.freeze_base_model()
|
||||
|
||||
input_values = input_values[:3]
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
|
||||
loss = model(input_values, labels=labels).loss
|
||||
self.parent.assertFalse(torch.isinf(loss).item())
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = SEWForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -241,7 +295,7 @@ class SEWModelTester:
|
||||
|
||||
@require_torch
|
||||
class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SEWForCTC, SEWModel) if is_torch_available() else ()
|
||||
all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@@ -328,6 +382,14 @@ class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(hidden_states.grad)
|
||||
self.assertIsNotNone(attentions.grad)
|
||||
|
||||
def test_seq_classifier_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||
|
||||
def test_seq_classifier_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import SEWDForCTC, SEWDModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers import (
|
||||
SEWDForCTC,
|
||||
SEWDForSequenceClassification,
|
||||
SEWDModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||
|
||||
|
||||
@@ -240,6 +246,54 @@ class SEWDModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_seq_classifier_loss(self, config, input_values, *args):
|
||||
model = SEWDForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
|
||||
# make sure that dropout is disabled
|
||||
model.eval()
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0
|
||||
|
||||
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||
unmasked_loss = model(input_values, labels=labels).loss.item()
|
||||
|
||||
self.parent.assertTrue(isinstance(masked_loss, float))
|
||||
self.parent.assertTrue(isinstance(unmasked_loss, float))
|
||||
self.parent.assertTrue(masked_loss != unmasked_loss)
|
||||
|
||||
def check_seq_classifier_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = SEWDForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# freeze everything but the classification head
|
||||
model.freeze_base_model()
|
||||
|
||||
input_values = input_values[:3]
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
|
||||
loss = model(input_values, labels=labels).loss
|
||||
self.parent.assertFalse(torch.isinf(loss).item())
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = SEWDForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -262,7 +316,7 @@ class SEWDModelTester:
|
||||
|
||||
@require_torch
|
||||
class SEWDModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SEWDForCTC, SEWDModel) if is_torch_available() else ()
|
||||
all_model_classes = (SEWDForCTC, SEWDModel, SEWDForSequenceClassification) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
|
||||
Reference in New Issue
Block a user