Add TensorFlow Wav2Vec2 for sequence classification (#22073)
* Add initial changes for TF wav2vec2 for sequence classification * Add suggested changes * Add serving and serving output methods * Add serving_output implementation and fix layer_weights * Add fixes * Fixed test cases * Fixing test and adding suggested changes
This commit is contained in:
@@ -50,7 +50,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFWav2Vec2ForCTC, TFWav2Vec2Model, Wav2Vec2Processor
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
TFWav2Vec2ForCTC,
|
||||
TFWav2Vec2ForSequenceClassification,
|
||||
TFWav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
|
||||
|
||||
|
||||
@@ -247,6 +253,29 @@ class TFWav2Vec2ModelTester:
|
||||
|
||||
self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
|
||||
|
||||
def check_seq_classifier_loss(self, loss, config, input_values, *args):
|
||||
model = TFWav2Vec2ForSequenceClassification(config)
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = tf.ones(input_values.shape, dtype=tf.int32)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = tf.random.uniform((input_values.shape[0],), maxval=len(model.config.id2label), dtype=tf.int32)
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0
|
||||
training = False
|
||||
masked_loss = (
|
||||
model(input_values, attention_mask=attention_mask, labels=labels, training=training).loss.numpy().item()
|
||||
)
|
||||
unmasked_loss = model(input_values, labels=labels, training=training).loss.numpy().item()
|
||||
|
||||
assert isinstance(masked_loss, float)
|
||||
assert isinstance(unmasked_loss, float)
|
||||
assert masked_loss != unmasked_loss
|
||||
|
||||
def check_training(self, config, input_values, *args):
|
||||
model = TFWav2Vec2ForCTC(config)
|
||||
|
||||
@@ -286,8 +315,14 @@ class TFWav2Vec2ModelTester:
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": TFWav2Vec2Model} if is_tf_available() else {}
|
||||
all_model_classes = (
|
||||
(TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": TFWav2Vec2Model, "audio-classification": TFWav2Vec2ForSequenceClassification}
|
||||
if is_tf_available()
|
||||
else {}
|
||||
)
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
@@ -459,7 +494,9 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
|
||||
all_model_classes = (
|
||||
(TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
|
||||
)
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
@@ -679,6 +716,11 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def _load_superb(self, task, num_samples):
|
||||
ds = load_dataset("anton-l/superb_dummy", task, split="test")
|
||||
|
||||
return ds[:num_samples]
|
||||
|
||||
def test_inference_ctc_normal(self):
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
@@ -791,3 +833,87 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
|
||||
|
||||
def test_inference_keyword_spotting(self):
|
||||
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
|
||||
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
|
||||
input_data = self._load_superb("ks", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
|
||||
input_values = inputs.input_values
|
||||
attention_mask = inputs.attention_mask
|
||||
outputs = model(input_values, attention_mask)
|
||||
predicted_logits, predicted_ids = tf.math.reduce_max(outputs.logits, axis=-1), tf.argmax(
|
||||
outputs.logits, axis=-1
|
||||
)
|
||||
expected_labels = [7, 6, 10, 9]
|
||||
expected_logits = tf.convert_to_tensor([6.1186, 11.8961, 10.2931, 6.0898])
|
||||
self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
|
||||
self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_intent_classification(self):
|
||||
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic", from_pt=True)
|
||||
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
|
||||
input_data = self._load_superb("ic", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
|
||||
input_values = inputs.input_values
|
||||
attention_mask = inputs.attention_mask
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits_action, predicted_ids_action = tf.math.reduce_max(outputs.logits[:, :6], axis=-1), tf.argmax(
|
||||
outputs.logits[:, :6], axis=-1
|
||||
)
|
||||
predicted_logits_object, predicted_ids_object = tf.math.reduce_max(
|
||||
outputs.logits[:, 6:20], axis=-1
|
||||
), tf.argmax(outputs.logits[:, 6:20], axis=-1)
|
||||
predicted_logits_location, predicted_ids_location = tf.math.reduce_max(
|
||||
outputs.logits[:, 20:24], axis=-1
|
||||
), tf.argmax(outputs.logits[:, 20:24], axis=-1)
|
||||
expected_labels_action = [0, 0, 2, 3]
|
||||
expected_logits_action = tf.convert_to_tensor([0.4568, 11.0848, 1.6621, 9.3841])
|
||||
expected_labels_object = [3, 10, 3, 4]
|
||||
expected_logits_object = tf.convert_to_tensor([1.5322, 10.7094, 5.2469, 22.1318])
|
||||
expected_labels_location = [0, 0, 0, 1]
|
||||
expected_logits_location = tf.convert_to_tensor([1.5335, 6.5096, 10.5704, 11.0569])
|
||||
|
||||
self.assertListEqual(predicted_ids_action.numpy().tolist(), expected_labels_action)
|
||||
self.assertListEqual(predicted_ids_object.numpy().tolist(), expected_labels_object)
|
||||
self.assertListEqual(predicted_ids_location.numpy().tolist(), expected_labels_location)
|
||||
|
||||
self.assertTrue(np.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
|
||||
self.assertTrue(np.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
|
||||
self.assertTrue(np.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
|
||||
|
||||
def test_inference_speaker_identification(self):
|
||||
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid", from_pt=True)
|
||||
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid")
|
||||
input_data = self._load_superb("si", 4)
|
||||
output_logits = []
|
||||
for example in input_data["speech"]:
|
||||
input = processor(example, return_tensors="tf", padding=True)
|
||||
output = model(input.input_values, attention_mask=None)
|
||||
output_logits.append(output.logits[0])
|
||||
output_logits = tf.stack(output_logits)
|
||||
predicted_logits, predicted_ids = tf.math.reduce_max(output_logits, axis=-1), tf.argmax(output_logits, axis=-1)
|
||||
expected_labels = [251, 1, 1, 3]
|
||||
expected_logits = tf.convert_to_tensor([37.5627, 71.6362, 64.2419, 31.7778])
|
||||
self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
|
||||
self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_emotion_recognition(self):
|
||||
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er", from_pt=True)
|
||||
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
|
||||
input_data = self._load_superb("er", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
|
||||
|
||||
input_values = inputs.input_values
|
||||
attention_mask = inputs.attention_mask
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits, predicted_ids = tf.math.reduce_max(outputs.logits, axis=-1), tf.argmax(
|
||||
outputs.logits, axis=-1
|
||||
)
|
||||
|
||||
expected_labels = [1, 1, 2, 2]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = tf.convert_to_tensor([2.1722, 3.0779, 8.0287, 6.6797])
|
||||
|
||||
self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
|
||||
self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
Reference in New Issue
Block a user