Add TensorFlow Wav2Vec2 for sequence classification (#22073)
* Add initial changes for TF wav2vec2 for sequence classification * Add suggested changes * Add serving and serving output methods * Add serving_output implementation and fix layer_weights * Add fixes * Fixed test cases * Fixing test and adding suggested changes
This commit is contained in:
@@ -197,6 +197,11 @@ Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower
|
|||||||
[[autodoc]] TFWav2Vec2Model
|
[[autodoc]] TFWav2Vec2Model
|
||||||
- call
|
- call
|
||||||
|
|
||||||
|
## TFWav2Vec2ForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] TFWav2Vec2ForSequenceClassification
|
||||||
|
- call
|
||||||
|
|
||||||
## TFWav2Vec2ForCTC
|
## TFWav2Vec2ForCTC
|
||||||
|
|
||||||
[[autodoc]] TFWav2Vec2ForCTC
|
[[autodoc]] TFWav2Vec2ForCTC
|
||||||
|
|||||||
@@ -3443,6 +3443,7 @@ else:
|
|||||||
[
|
[
|
||||||
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"TFWav2Vec2ForCTC",
|
"TFWav2Vec2ForCTC",
|
||||||
|
"TFWav2Vec2ForSequenceClassification",
|
||||||
"TFWav2Vec2Model",
|
"TFWav2Vec2Model",
|
||||||
"TFWav2Vec2PreTrainedModel",
|
"TFWav2Vec2PreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -6626,6 +6627,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.wav2vec2 import (
|
from .models.wav2vec2 import (
|
||||||
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFWav2Vec2ForCTC,
|
TFWav2Vec2ForCTC,
|
||||||
|
TFWav2Vec2ForSequenceClassification,
|
||||||
TFWav2Vec2Model,
|
TFWav2Vec2Model,
|
||||||
TFWav2Vec2PreTrainedModel,
|
TFWav2Vec2PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -351,6 +351,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
|
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
|
||||||
|
|
||||||
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
@@ -471,6 +472,9 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
|
|||||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
|
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModel(_BaseAutoModelClass):
|
class TFAutoModel(_BaseAutoModelClass):
|
||||||
@@ -480,6 +484,15 @@ class TFAutoModel(_BaseAutoModelClass):
|
|||||||
TFAutoModel = auto_class_update(TFAutoModel)
|
TFAutoModel = auto_class_update(TFAutoModel)
|
||||||
|
|
||||||
|
|
||||||
|
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    # Lazily resolved mapping used to pick the audio classification model class.
    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# Rebind the public name to whatever `auto_class_update` returns for the "audio classification" head.
TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification,
    head_doc="audio classification",
)
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModelForPreTraining(_BaseAutoModelClass):
|
class TFAutoModelForPreTraining(_BaseAutoModelClass):
|
||||||
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
|
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ else:
|
|||||||
"TFWav2Vec2ForCTC",
|
"TFWav2Vec2ForCTC",
|
||||||
"TFWav2Vec2Model",
|
"TFWav2Vec2Model",
|
||||||
"TFWav2Vec2PreTrainedModel",
|
"TFWav2Vec2PreTrainedModel",
|
||||||
|
"TFWav2Vec2ForSequenceClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -108,6 +109,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_tf_wav2vec2 import (
|
from .modeling_tf_wav2vec2 import (
|
||||||
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFWav2Vec2ForCTC,
|
TFWav2Vec2ForCTC,
|
||||||
|
TFWav2Vec2ForSequenceClassification,
|
||||||
TFWav2Vec2Model,
|
TFWav2Vec2Model,
|
||||||
TFWav2Vec2PreTrainedModel,
|
TFWav2Vec2PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
get_initializer,
|
get_initializer,
|
||||||
@@ -1212,6 +1212,46 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
|
|||||||
|
|
||||||
return self.serving_output(output)
|
return self.serving_output(output)
|
||||||
|
|
||||||
|
def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
    """
    Compute the sequence lengths obtained after the convolutional layers.

    Args:
        input_lengths: Lengths to push through the convolution length formula.
        add_adapter: Whether to also apply the adapter strides. Falls back to `self.config.add_adapter` when
            `None`.

    Returns:
        The lengths after applying every `(conv_kernel, conv_stride)` pair from the config and, if enabled,
        `num_adapter_layers` additional reductions with kernel size 1 and stride `adapter_stride`.
    """
    use_adapter = add_adapter if add_adapter is not None else self.config.add_adapter

    def _conv_out_length(input_length, kernel_size, stride):
        # floor((length - kernel) / stride) + 1
        return tf.math.floordiv(input_length - kernel_size, stride) + 1

    lengths = input_lengths
    for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
        lengths = _conv_out_length(lengths, kernel_size, stride)

    if use_adapter:
        for _ in range(self.config.num_adapter_layers):
            lengths = _conv_out_length(lengths, 1, self.config.adapter_stride)

    return lengths
|
||||||
|
|
||||||
|
def _get_feature_vector_attention_mask(
    self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None
):
    """
    Build a boolean mask over the feature vectors from a mask over the raw input.

    Args:
        feature_vector_length: Size of the second axis of the returned mask.
        attention_mask: Input-level mask; its last-axis sum is used as the number of non-padded positions per row.
        add_adapter: Forwarded to `_get_feat_extract_output_lengths`.

    Returns:
        A `tf.bool` tensor of shape `(batch_size, feature_vector_length)` that is `True` for every position
        before each row's output length.
    """
    mask_dtype = attention_mask.dtype
    num_rows = tf.shape(attention_mask)[0]

    # The last entry of the running sum is the count of non-padded input positions of each row.
    valid_input_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
    valid_feature_lengths = tf.cast(
        self._get_feat_extract_output_lengths(valid_input_lengths, add_adapter=add_adapter), tf.int32
    )

    # Step 1: set a single one at the last valid feature position of every row.
    last_valid_positions = tf.stack([tf.range(num_rows), valid_feature_lengths - 1], axis=1)
    feature_mask = tf.tensor_scatter_nd_update(
        tf.zeros((num_rows, feature_vector_length), dtype=mask_dtype, name="attention_mask"),
        indices=last_valid_positions,
        updates=tf.ones([num_rows], dtype=mask_dtype),
    )

    # Step 2: a right-to-left cumulative sum spreads that one over all earlier positions, so that all values
    # before the output length index are attended to.
    reversed_running_sum = tf.cumsum(tf.reverse(feature_mask, axis=[-1]), axis=-1)
    feature_mask = tf.reverse(reversed_running_sum, axis=[-1])

    return tf.cast(feature_mask, tf.bool)
|
||||||
|
|
||||||
|
|
||||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||||
|
|
||||||
@@ -1552,3 +1592,125 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
|||||||
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||||
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||||
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
|
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
|
||||||
|
|
||||||
|
|
||||||
|
class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
    """
    Wav2Vec2 main layer followed by a dense projection, mean pooling over the time axis and a dense classifier.

    When `config.use_weighted_layer_sum` is set, the pooled representation is built from a softmax-weighted sum of
    all hidden states instead of the last hidden state only.
    """

    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
        # Number of hidden-state entries combined by `layer_weights`.
        self.num_layers = config.num_hidden_layers + 1
        with tf.name_scope(self._name_scope()):
            if config.use_weighted_layer_sum:
                self.layer_weights = self.add_weight(
                    shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
                )
        self.config = config
        self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, name="projector")
        self.classifier = tf.keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            # The trailing space keeps the two concatenated sentences from running together.
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor.trainable = False

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        # `self.wav2vec2` is a layer, and a plain Keras layer is not guaranteed to expose a `.layers` attribute the
        # way a Keras model does. Setting `trainable` on the layer itself is propagated to its sub-layers.
        self.wav2vec2.trainable = False

    @unpack_inputs
    def call(
        self,
        input_values: tf.Tensor,
        attention_mask: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[tf.Tensor] = None,
        training: bool = False,
    ):
        """
        Run the model and return classification logits (and the loss when `labels` is given).

        Args:
            input_values: Input forwarded to the Wav2Vec2 main layer.
            attention_mask: Optional input-level mask. When given, padded feature vectors are excluded from the mean
                pooling; when `None`, a plain mean over the time axis is used.
            output_attentions: Forwarded to the main layer.
            output_hidden_states: Forwarded to the main layer; forced to `True` when
                `config.use_weighted_layer_sum` is set because all hidden states are then needed.
            return_dict: Whether to return a `TFSequenceClassifierOutput`; defaults to `config.use_return_dict`.
            labels: Optional class indices used to compute a sparse categorical cross-entropy loss on the logits.
            training: Forwarded to the main layer.

        Returns:
            A `TFSequenceClassifierOutput`, or a tuple `(loss?, logits, *extra outputs)` when `return_dict` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if self.config.use_weighted_layer_sum:
            # Stack all hidden states on a new axis 1 and combine them with softmax-normalised weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = tf.stack(hidden_states, axis=1)
            norm_weights = tf.nn.softmax(self.layer_weights, axis=-1)
            hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = tf.reduce_mean(hidden_states, axis=1)
        else:
            # The static time dimension is `None` when the function is traced with an unknown sequence length
            # (e.g. the `(None, None)` serving signature); fall back to the dynamic shape in that case.
            feature_length = hidden_states.shape[1]
            if feature_length is None:
                feature_length = tf.shape(hidden_states)[1]
            padding_mask = self._get_feature_vector_attention_mask(feature_length, attention_mask)
            padding_mask_float = tf.cast(padding_mask, hidden_states.dtype)
            # Zero out padded positions, then divide the sum by the number of valid positions (masked mean).
            hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1))
            pooled_output = tf.divide(
                tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1)
            )
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels]))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def serving_output(self, output):
        """
        Convert a model output into the structure returned by `serving`.

        Hidden states and attentions are converted to tensors only when the corresponding config flag is set;
        otherwise they are returned as `None`. The loss is not included.
        """
        hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

        return TFSequenceClassifierOutput(
            logits=output.logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    @tf.function(
        input_signature=[
            {
                "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
                "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
                "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
            }
        ]
    )
    def serving(self, inputs):
        """
        Serving entry point taking a dict with the keys declared in the input signature.

        NOTE(review): the whole dict is passed as `input_values`; presumably `unpack_inputs` spreads it over the
        `call` arguments -- confirm. `call` itself declares no `token_type_ids` parameter.
        """
        output = self.call(input_values=inputs)

        return self.serving_output(output)
|
||||||
|
|||||||
@@ -2590,6 +2590,13 @@ class TFWav2Vec2ForCTC(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFWav2Vec2ForSequenceClassification(metaclass=DummyObject):
    # Backends checked by `requires_backends` whenever this placeholder is instantiated.
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFWav2Vec2Model(metaclass=DummyObject):
|
class TFWav2Vec2Model(metaclass=DummyObject):
|
||||||
_backends = ["tf"]
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFWav2Vec2ForCTC, TFWav2Vec2Model, Wav2Vec2Processor
|
from transformers import (
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
TFWav2Vec2ForCTC,
|
||||||
|
TFWav2Vec2ForSequenceClassification,
|
||||||
|
TFWav2Vec2Model,
|
||||||
|
Wav2Vec2Processor,
|
||||||
|
)
|
||||||
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
|
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
|
||||||
|
|
||||||
|
|
||||||
@@ -247,6 +253,29 @@ class TFWav2Vec2ModelTester:
|
|||||||
|
|
||||||
self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
|
self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
|
||||||
|
|
||||||
|
def check_seq_classifier_loss(self, loss, config, input_values, *args):
    """
    Check that the sequence classification loss changes when padded inputs are masked.

    The first (up to) three rows of `input_values` are zero-padded to 1/4, 1/2 and full length, and the loss is
    computed once with the matching attention mask and once without any mask.

    Args:
        loss: Unused; kept so the signature stays unchanged.
        config: Configuration used to build a fresh `TFWav2Vec2ForSequenceClassification`.
        input_values: Batch of inputs; only the first three rows are used.
    """
    model = TFWav2Vec2ForSequenceClassification(config)

    # TensorFlow tensors do not support item assignment, so the padding is applied on writable NumPy copies
    # and converted back to tensors afterwards.
    input_values = np.array(input_values[:3])
    attention_mask = np.ones(input_values.shape, dtype=np.int32)

    input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
    labels = tf.random.uniform((input_values.shape[0],), maxval=len(model.config.id2label), dtype=tf.int32)

    # pad input (only as many rows as the batch actually has)
    for i, length in enumerate(input_lengths[: input_values.shape[0]]):
        input_values[i, length:] = 0.0
        attention_mask[i, length:] = 0

    input_values = tf.convert_to_tensor(input_values)
    attention_mask = tf.convert_to_tensor(attention_mask)

    training = False
    masked_loss = (
        model(input_values, attention_mask=attention_mask, labels=labels, training=training).loss.numpy().item()
    )
    unmasked_loss = model(input_values, labels=labels, training=training).loss.numpy().item()

    # Report through the parent test case, like the other checks of this tester.
    self.parent.assertIsInstance(masked_loss, float)
    self.parent.assertIsInstance(unmasked_loss, float)
    self.parent.assertNotEqual(masked_loss, unmasked_loss)
|
||||||
|
|
||||||
def check_training(self, config, input_values, *args):
|
def check_training(self, config, input_values, *args):
|
||||||
model = TFWav2Vec2ForCTC(config)
|
model = TFWav2Vec2ForCTC(config)
|
||||||
|
|
||||||
@@ -286,8 +315,14 @@ class TFWav2Vec2ModelTester:
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
|
all_model_classes = (
|
||||||
pipeline_model_mapping = {"feature-extraction": TFWav2Vec2Model} if is_tf_available() else {}
|
(TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
|
||||||
|
)
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{"feature-extraction": TFWav2Vec2Model, "audio-classification": TFWav2Vec2ForSequenceClassification}
|
||||||
|
if is_tf_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_onnx = False
|
test_onnx = False
|
||||||
@@ -459,7 +494,9 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
|
all_model_classes = (
|
||||||
|
(TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
|
||||||
|
)
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_onnx = False
|
test_onnx = False
|
||||||
@@ -679,6 +716,11 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
return [x["array"] for x in speech_samples]
|
return [x["array"] for x in speech_samples]
|
||||||
|
|
||||||
|
def _load_superb(self, task, num_samples):
    """Load the `test` split of `anton-l/superb_dummy` for `task` and return its first `num_samples` entries."""
    dataset = load_dataset("anton-l/superb_dummy", task, split="test")
    head = slice(None, num_samples)
    return dataset[head]
|
||||||
|
|
||||||
def test_inference_ctc_normal(self):
|
def test_inference_ctc_normal(self):
|
||||||
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||||
@@ -791,3 +833,87 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
@require_librosa
|
@require_librosa
|
||||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||||
run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
|
run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
|
||||||
|
|
||||||
|
def test_inference_keyword_spotting(self):
    """Compare predicted class ids and their top logits for the `ks` samples against reference values."""
    checkpoint = "superb/wav2vec2-base-superb-ks"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

    samples = self._load_superb("ks", 4)
    batch = feature_extractor(samples["speech"], return_tensors="tf", padding=True)

    # The attention mask is passed positionally, as the second argument.
    logits = model(batch.input_values, batch.attention_mask).logits
    top_logits = tf.math.reduce_max(logits, axis=-1)
    top_ids = tf.argmax(logits, axis=-1)

    reference_ids = [7, 6, 10, 9]
    reference_logits = tf.convert_to_tensor([6.1186, 11.8961, 10.2931, 6.0898])

    self.assertListEqual(top_ids.numpy().tolist(), reference_ids)
    self.assertTrue(np.allclose(top_logits, reference_logits, atol=1e-2))
|
||||||
|
|
||||||
|
def test_inference_intent_classification(self):
    """
    Compare predictions for the `ic` samples against reference values.

    The logits are split into three column groups (action, object, location); for each group the argmax and the
    maximum logit are checked.
    """
    checkpoint = "superb/wav2vec2-base-superb-ic"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

    samples = self._load_superb("ic", 4)
    batch = feature_extractor(samples["speech"], return_tensors="tf", padding=True)
    logits = model(batch.input_values, attention_mask=batch.attention_mask).logits

    # (logit columns, expected argmax inside the group, expected maximum logit), in the order
    # action, object, location.
    references = (
        (slice(None, 6), [0, 0, 2, 3], [0.4568, 11.0848, 1.6621, 9.3841]),
        (slice(6, 20), [3, 10, 3, 4], [1.5322, 10.7094, 5.2469, 22.1318]),
        (slice(20, 24), [0, 0, 0, 1], [1.5335, 6.5096, 10.5704, 11.0569]),
    )
    group_logits = [logits[:, columns] for columns, _, _ in references]
    predicted_ids = [tf.argmax(group, axis=-1) for group in group_logits]
    predicted_max = [tf.math.reduce_max(group, axis=-1) for group in group_logits]

    # All label checks run before any logit check.
    for ids, (_, expected_ids, _) in zip(predicted_ids, references):
        self.assertListEqual(ids.numpy().tolist(), expected_ids)

    for max_logits, (_, _, expected_max) in zip(predicted_max, references):
        self.assertTrue(np.allclose(max_logits, tf.convert_to_tensor(expected_max), atol=1e-2))
|
||||||
|
|
||||||
|
def test_inference_speaker_identification(self):
    """
    Compare predicted class ids and their top logits for the `si` samples against reference values.

    Each sample is run through the model on its own, without an attention mask, and the per-sample logits are
    stacked before the comparison.
    """
    checkpoint = "superb/wav2vec2-base-superb-sid"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    samples = self._load_superb("si", 4)

    per_sample_logits = []
    for speech in samples["speech"]:
        features = feature_extractor(speech, return_tensors="tf", padding=True)
        prediction = model(features.input_values, attention_mask=None)
        per_sample_logits.append(prediction.logits[0])
    stacked_logits = tf.stack(per_sample_logits)

    top_logits = tf.math.reduce_max(stacked_logits, axis=-1)
    top_ids = tf.argmax(stacked_logits, axis=-1)

    reference_ids = [251, 1, 1, 3]
    reference_logits = tf.convert_to_tensor([37.5627, 71.6362, 64.2419, 31.7778])

    self.assertListEqual(top_ids.numpy().tolist(), reference_ids)
    self.assertTrue(np.allclose(top_logits, reference_logits, atol=1e-2))
|
||||||
|
|
||||||
|
def test_inference_emotion_recognition(self):
    """Compare predicted class ids and their top logits for the `er` samples against reference values."""
    checkpoint = "superb/wav2vec2-base-superb-er"
    model = TFWav2Vec2ForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

    samples = self._load_superb("er", 4)
    batch = feature_extractor(samples["speech"], return_tensors="tf", padding=True)

    logits = model(batch.input_values, attention_mask=batch.attention_mask).logits
    top_logits = tf.math.reduce_max(logits, axis=-1)
    top_ids = tf.argmax(logits, axis=-1)

    reference_ids = [1, 1, 2, 2]
    # s3prl logits for the same batch
    reference_logits = tf.convert_to_tensor([2.1722, 3.0779, 8.0287, 6.6797])

    self.assertListEqual(top_ids.numpy().tolist(), reference_ids)
    self.assertTrue(np.allclose(top_logits, reference_logits, atol=1e-2))
|
||||||
|
|||||||
Reference in New Issue
Block a user