Add TensorFlow Wav2Vec2 for sequence classification (#22073)

* Add initial changes for TF wav2vec2 for sequence classification

* Add suggested changes

* Add serving and serving output methods

* Add serving_output implementation and fix layer_weights

* Add fixes

* Fixed test cases

* Fixing test and adding suggested changes
This commit is contained in:
Ritik Nandwal
2023-04-26 18:05:30 +05:30
committed by GitHub
parent 4c2b4c4c3c
commit 20ac86c6f1
7 changed files with 322 additions and 5 deletions

View File

@@ -50,7 +50,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFWav2Vec2ForCTC, TFWav2Vec2Model, Wav2Vec2Processor
from transformers import (
AutoFeatureExtractor,
TFWav2Vec2ForCTC,
TFWav2Vec2ForSequenceClassification,
TFWav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
@@ -247,6 +253,29 @@ class TFWav2Vec2ModelTester:
self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
def check_seq_classifier_loss(self, loss, config, input_values, *args):
    """Check that the sequence-classification loss depends on the attention mask.

    Instantiates ``TFWav2Vec2ForSequenceClassification`` from ``config``, keeps the
    first three rows of ``input_values``, zero-pads them to 1/4, 1/2 and the full
    length of the last axis, and compares the loss obtained with the matching
    attention mask against the loss obtained without any mask.

    Args:
        loss: Not used by this check; kept so the signature stays unchanged.
        config: Configuration object handed to the model constructor.
        input_values: Batch of model inputs; expected to hold at least three rows,
            as only ``input_values[:3]`` is used.
        *args: Ignored.

    Raises:
        AssertionError: If either loss is not a Python ``float`` or if both losses
            are equal.
    """
    model = TFWav2Vec2ForSequenceClassification(config)

    input_values = tf.convert_to_tensor(input_values[:3])
    max_length = input_values.shape[-1]
    input_lengths = [max_length // divisor for divisor in (4, 2, 1)]
    labels = tf.random.uniform((input_values.shape[0],), maxval=len(model.config.id2label), dtype=tf.int32)

    # TensorFlow tensors do not support item assignment, so the padding is expressed
    # as a mask: 1 for the first `input_lengths[i]` positions of row i, 0 afterwards.
    attention_mask = tf.sequence_mask(input_lengths, maxlen=max_length, dtype=tf.int32)
    # Zero out the padded positions of the inputs with the same mask.
    input_values = input_values * tf.cast(attention_mask, input_values.dtype)

    training = False
    masked_loss = (
        model(input_values, attention_mask=attention_mask, labels=labels, training=training).loss.numpy().item()
    )
    unmasked_loss = model(input_values, labels=labels, training=training).loss.numpy().item()

    self.parent.assertIsInstance(masked_loss, float)
    self.parent.assertIsInstance(unmasked_loss, float)
    self.parent.assertNotEqual(masked_loss, unmasked_loss)
def check_training(self, config, input_values, *args):
model = TFWav2Vec2ForCTC(config)
@@ -286,8 +315,14 @@ class TFWav2Vec2ModelTester:
@require_tf
class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
pipeline_model_mapping = {"feature-extraction": TFWav2Vec2Model} if is_tf_available() else {}
all_model_classes = (
(TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
)
pipeline_model_mapping = (
{"feature-extraction": TFWav2Vec2Model, "audio-classification": TFWav2Vec2ForSequenceClassification}
if is_tf_available()
else {}
)
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
@@ -459,7 +494,9 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
@require_tf
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
all_model_classes = (
(TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
)
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
@@ -679,6 +716,11 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
return [x["array"] for x in speech_samples]
def _load_superb(self, task, num_samples):
    """Load the first ``num_samples`` test examples of a SUPERB dummy task.

    Args:
        task: Configuration name passed to ``load_dataset`` for
            ``"anton-l/superb_dummy"``.
        num_samples: Upper bound of the slice taken from the ``"test"`` split.

    Returns:
        The result of slicing the loaded split with ``[:num_samples]``
        (callers in this file index it by column name, e.g. ``["speech"]``).
    """
    test_split = load_dataset("anton-l/superb_dummy", name=task, split="test")
    first_samples = test_split[:num_samples]
    return first_samples
def test_inference_ctc_normal(self):
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -791,3 +833,87 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
# Delegates the whole check to `run_test_in_subprocess`, handing it this test case
# and the module-level `_test_wav2vec2_with_lm_invalid_pool` target; no inputs are passed.
# NOTE(review): additional decorators may precede this one outside the visible hunk.
@require_librosa
def test_wav2vec2_with_lm_invalid_pool(self):
run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
def test_inference_keyword_spotting(self):
    """Compare keyword-spotting predictions against stored reference values.

    Loads the ``superb/wav2vec2-base-superb-ks`` checkpoint (converted from
    PyTorch weights), runs it on four ``"ks"`` samples from ``_load_superb`` and
    checks the arg-max label and the maximum logit of every sample.
    """
    checkpoint = "superb/wav2vec2-base-superb-ks"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    processor = AutoFeatureExtractor.from_pretrained(checkpoint)
    input_data = self._load_superb("ks", 4)
    inputs = processor(input_data["speech"], return_tensors="tf", padding=True)

    # Pass the mask by keyword, like the other inference tests in this class do.
    outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)
    logits = outputs.logits
    predicted_logits = tf.math.reduce_max(logits, axis=-1)
    predicted_ids = tf.argmax(logits, axis=-1)

    expected_labels = [7, 6, 10, 9]
    expected_logits = tf.convert_to_tensor([6.1186, 11.8961, 10.2931, 6.0898])

    self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
    self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_intent_classification(self):
    """Compare intent-classification predictions against stored reference values.

    Runs the ``superb/wav2vec2-base-superb-ic`` checkpoint on four ``"ic"``
    samples. The logits are split into three column groups (0-5 action,
    6-19 object, 20-23 location); for each group the arg-max index and the
    maximum logit are checked.
    """
    checkpoint = "superb/wav2vec2-base-superb-ic"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    processor = AutoFeatureExtractor.from_pretrained(checkpoint)
    input_data = self._load_superb("ic", 4)
    inputs = processor(input_data["speech"], return_tensors="tf", padding=True)

    outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)

    # One slice of the logits per prediction head.
    action_logits = outputs.logits[:, :6]
    object_logits = outputs.logits[:, 6:20]
    location_logits = outputs.logits[:, 20:24]

    predicted_logits_action = tf.math.reduce_max(action_logits, axis=-1)
    predicted_ids_action = tf.argmax(action_logits, axis=-1)
    predicted_logits_object = tf.math.reduce_max(object_logits, axis=-1)
    predicted_ids_object = tf.argmax(object_logits, axis=-1)
    predicted_logits_location = tf.math.reduce_max(location_logits, axis=-1)
    predicted_ids_location = tf.argmax(location_logits, axis=-1)

    expected_labels_action = [0, 0, 2, 3]
    expected_labels_object = [3, 10, 3, 4]
    expected_labels_location = [0, 0, 0, 1]
    expected_logits_action = tf.convert_to_tensor([0.4568, 11.0848, 1.6621, 9.3841])
    expected_logits_object = tf.convert_to_tensor([1.5322, 10.7094, 5.2469, 22.1318])
    expected_logits_location = tf.convert_to_tensor([1.5335, 6.5096, 10.5704, 11.0569])

    # Labels first, then logits, for action, object and location in turn.
    self.assertListEqual(predicted_ids_action.numpy().tolist(), expected_labels_action)
    self.assertListEqual(predicted_ids_object.numpy().tolist(), expected_labels_object)
    self.assertListEqual(predicted_ids_location.numpy().tolist(), expected_labels_location)

    self.assertTrue(np.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
    self.assertTrue(np.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
    self.assertTrue(np.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
def test_inference_speaker_identification(self):
    """Compare speaker-identification predictions against stored reference values.

    Runs the ``superb/wav2vec2-base-superb-sid`` checkpoint on four ``"si"``
    samples, one sample per forward pass and without an attention mask, then
    checks the arg-max label and the maximum logit of every sample.
    """
    checkpoint = "superb/wav2vec2-base-superb-sid"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    processor = AutoFeatureExtractor.from_pretrained(checkpoint)
    input_data = self._load_superb("si", 4)

    def _single_example_logits(example):
        # Each example is processed on its own; the first row of the logits is kept.
        encoded = processor(example, return_tensors="tf", padding=True)
        return model(encoded.input_values, attention_mask=None).logits[0]

    output_logits = tf.stack([_single_example_logits(example) for example in input_data["speech"]])
    predicted_logits = tf.math.reduce_max(output_logits, axis=-1)
    predicted_ids = tf.argmax(output_logits, axis=-1)

    expected_labels = [251, 1, 1, 3]
    expected_logits = tf.convert_to_tensor([37.5627, 71.6362, 64.2419, 31.7778])

    self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
    self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_emotion_recognition(self):
    """Compare emotion-recognition predictions against stored reference values.

    Runs the ``superb/wav2vec2-base-superb-er`` checkpoint on four ``"er"``
    samples as one padded batch and checks the arg-max label and the maximum
    logit of every sample.
    """
    checkpoint = "superb/wav2vec2-base-superb-er"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    processor = AutoFeatureExtractor.from_pretrained(checkpoint)
    input_data = self._load_superb("er", 4)
    inputs = processor(input_data["speech"], return_tensors="tf", padding=True)

    outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)
    logits = outputs.logits
    predicted_logits = tf.math.reduce_max(logits, axis=-1)
    predicted_ids = tf.argmax(logits, axis=-1)

    expected_labels = [1, 1, 2, 2]
    # s3prl logits for the same batch
    expected_logits = tf.convert_to_tensor([2.1722, 3.0779, 8.0287, 6.6797])

    self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
    self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))