add return_token_timestamps to WhisperProcessor (#30812)
* compute num_frames in WhisperFeatureExtractor
* add return_num_frames in WhisperFeatureProcessor + adapt pipeline
* return_timestamps renaming + pipeline fix
* fix
* fix
* fix
* add tests
* Update src/transformers/models/whisper/feature_extraction_whisper.py
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* apply review changes
* fix
* Update src/transformers/models/whisper/feature_extraction_whisper.py
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update tests/models/whisper/test_modeling_whisper.py
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* apply review
* fix
* review changes
* Update src/transformers/models/whisper/feature_extraction_whisper.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* make style quality
* EXPECTED_OUTPUT in single line
* small numpy->torch fix
* fix
---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -188,6 +188,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
sampling_rate: Optional[int] = None,
|
sampling_rate: Optional[int] = None,
|
||||||
do_normalize: Optional[bool] = None,
|
do_normalize: Optional[bool] = None,
|
||||||
device: Optional[str] = "cpu",
|
device: Optional[str] = "cpu",
|
||||||
|
return_token_timestamps: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
@@ -237,6 +238,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
device (`str`, *optional*, defaults to `'cpu'`):
|
device (`str`, *optional*, defaults to `'cpu'`):
|
||||||
Specifies the device for computation of the log-mel spectrogram of audio signals in the
|
Specifies the device for computation of the log-mel spectrogram of audio signals in the
|
||||||
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
|
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
|
||||||
|
return_token_timestamps (`bool`, *optional*, defaults to `None`):
|
||||||
|
Whether or not to return the number of frames of the input raw_speech.
|
||||||
|
These num_frames can be used by the model to compute word level timestamps.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if sampling_rate is not None:
|
if sampling_rate is not None:
|
||||||
@@ -302,6 +306,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
|
|
||||||
if isinstance(input_features[0], List):
|
if isinstance(input_features[0], List):
|
||||||
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
padded_inputs["input_features"] = input_features
|
padded_inputs["input_features"] = input_features
|
||||||
|
|
||||||
@@ -309,6 +314,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
# rescale from sample (48000) to feature (3000)
|
# rescale from sample (48000) to feature (3000)
|
||||||
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
|
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
|
||||||
|
|
||||||
|
if return_token_timestamps is not None:
|
||||||
|
padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
|
||||||
|
|
||||||
if return_tensors is not None:
|
if return_tensors is not None:
|
||||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||||
|
|
||||||
|
|||||||
@@ -209,11 +209,15 @@ class WhisperGenerationMixin:
|
|||||||
# 2. num_frames is different, compute the DTW matrix for each sample sequentially
|
# 2. num_frames is different, compute the DTW matrix for each sample sequentially
|
||||||
|
|
||||||
# we're using np.unique because num_frames can be int/list/tuple
|
# we're using np.unique because num_frames can be int/list/tuple
|
||||||
if len(np.unique(num_frames)) == 1:
|
if isinstance(num_frames, int):
|
||||||
# if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch
|
|
||||||
num_frames = num_frames if isinstance(num_frames, int) else num_frames[0]
|
|
||||||
|
|
||||||
weights = weights[..., : num_frames // 2]
|
weights = weights[..., : num_frames // 2]
|
||||||
|
|
||||||
|
elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
|
||||||
|
weights = weights[..., : num_frames[0] // 2]
|
||||||
|
|
||||||
|
elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
|
||||||
|
weights = weights[..., : num_frames[0] // 2]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
|
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
|
||||||
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
|
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
|
||||||
@@ -231,7 +235,7 @@ class WhisperGenerationMixin:
|
|||||||
|
|
||||||
# Perform dynamic time warping on each element of the batch.
|
# Perform dynamic time warping on each element of the batch.
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)):
|
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
|
||||||
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
|
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
|
||||||
|
|
||||||
# Normalize and smoothen the weights.
|
# Normalize and smoothen the weights.
|
||||||
@@ -475,6 +479,7 @@ class WhisperGenerationMixin:
|
|||||||
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
|
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. prepare generation config
|
# 1. prepare generation config
|
||||||
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -442,12 +442,19 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
padding="longest",
|
padding="longest",
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if self.type == "seq2seq_whisper" and stride is None:
|
||||||
|
processed = self.feature_extractor(
|
||||||
|
inputs,
|
||||||
|
sampling_rate=self.feature_extractor.sampling_rate,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_token_timestamps=True,
|
||||||
|
)
|
||||||
|
extra["num_frames"] = processed.pop("num_frames")
|
||||||
else:
|
else:
|
||||||
processed = self.feature_extractor(
|
processed = self.feature_extractor(
|
||||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||||
)
|
)
|
||||||
if stride is None:
|
|
||||||
extra["segment_size"] = len(inputs)
|
|
||||||
|
|
||||||
if self.torch_dtype is not None:
|
if self.torch_dtype is not None:
|
||||||
processed = processed.to(dtype=self.torch_dtype)
|
processed = processed.to(dtype=self.torch_dtype)
|
||||||
@@ -461,11 +468,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
|
def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
|
||||||
attention_mask = model_inputs.pop("attention_mask", None)
|
attention_mask = model_inputs.pop("attention_mask", None)
|
||||||
stride = model_inputs.pop("stride", None)
|
stride = model_inputs.pop("stride", None)
|
||||||
segment_size = model_inputs.pop("segment_size", None)
|
num_frames = model_inputs.pop("num_frames", None)
|
||||||
is_last = model_inputs.pop("is_last")
|
is_last = model_inputs.pop("is_last")
|
||||||
|
|
||||||
if stride is not None and segment_size is not None:
|
if stride is not None and num_frames is not None:
|
||||||
raise ValueError("segment_size must be used only when stride is None")
|
raise ValueError("num_frames must be used only when stride is None")
|
||||||
|
|
||||||
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
||||||
encoder = self.model.get_encoder()
|
encoder = self.model.get_encoder()
|
||||||
@@ -495,10 +502,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
|
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if isinstance(segment_size, int):
|
generate_kwargs["num_frames"] = num_frames
|
||||||
generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length
|
|
||||||
else:
|
|
||||||
generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length
|
|
||||||
|
|
||||||
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
|
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
|
||||||
generate_kwargs["input_features"] = inputs
|
generate_kwargs["input_features"] = inputs
|
||||||
|
|||||||
@@ -1950,6 +1950,69 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_large_timestamp_generation(self):
|
||||||
|
set_seed(0)
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
input_speech = np.concatenate(self._load_datasamples(4))
|
||||||
|
input_features = processor(
|
||||||
|
input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True
|
||||||
|
).input_features
|
||||||
|
input_features = input_features.to(torch_device)
|
||||||
|
|
||||||
|
generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50257])
|
||||||
|
# fmt: on
|
||||||
|
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))
|
||||||
|
|
||||||
|
EXPECTED_TRANSCRIPT = [
|
||||||
|
{
|
||||||
|
"text": (
|
||||||
|
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
|
||||||
|
" Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive"
|
||||||
|
" season of the year, with Christmas and roast beef looming before us, similes drawn from eating"
|
||||||
|
" and its results occur most readily to the mind. He has grave doubts whether Sir Frederick "
|
||||||
|
"Leighton's work is really Greek after all,"
|
||||||
|
),
|
||||||
|
"offsets": [
|
||||||
|
{
|
||||||
|
"text": (
|
||||||
|
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
|
||||||
|
),
|
||||||
|
"timestamp": (0.0, 5.28),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||||
|
"timestamp": (6.34, 10.1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": (
|
||||||
|
" He tells us that at this festive season of the year, with Christmas and roast beef looming before us,"
|
||||||
|
),
|
||||||
|
"timestamp": (10.92, 17.6),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": (" similes drawn from eating and its results occur most readily to the mind."),
|
||||||
|
"timestamp": (18.44, 22.580000000000002),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": (
|
||||||
|
" He has grave doubts whether Sir Frederick Leighton's work is really Greek after all,"
|
||||||
|
),
|
||||||
|
"timestamp": (23.16, 28.68),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||||
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tiny_token_timestamp_generation(self):
|
def test_tiny_token_timestamp_generation(self):
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
@@ -1979,6 +2042,36 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_large_token_timestamp_generation(self):
|
||||||
|
set_seed(0)
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
input_speech = self._load_datasamples(4)
|
||||||
|
input_features = processor(
|
||||||
|
input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True
|
||||||
|
)
|
||||||
|
input_features = input_features.to(torch_device)
|
||||||
|
|
||||||
|
generate_outputs = model.generate(
|
||||||
|
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
EXPECTED_OUTPUT = torch.tensor([
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000, 12.4600],
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800, 9.8800]
|
||||||
|
])
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tiny_token_timestamp_batch_generation(self):
|
def test_tiny_token_timestamp_batch_generation(self):
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user