From ea540a5977e602eb8072bfb3120c95f163c64a02 Mon Sep 17 00:00:00 2001 From: Arijit Mukherjee Date: Tue, 27 Sep 2022 16:42:56 +0530 Subject: [PATCH] add wav2vec2_alignment (#16782) * add wav2vec2_alignment * Update alignment.py * Update examples/research_projects/wav2vec2/alignment.py Co-authored-by: Patrick von Platen * Update examples/research_projects/wav2vec2/alignment.py Co-authored-by: Patrick von Platen * Update examples/research_projects/wav2vec2/alignment.py Co-authored-by: Patrick von Platen * Update examples/research_projects/wav2vec2/alignment.py Co-authored-by: Patrick von Platen * Update README.md * fix style * fix imports * fix multithread * fix bash script * [@anton-l] Style fixes and docstrings * [@anton-l] Style fixes and docstrings * Update alignment.py fix blank id in backtrack Co-authored-by: Patrick von Platen Co-authored-by: anton-l --- examples/research_projects/wav2vec2/README.md | 31 +++ .../research_projects/wav2vec2/alignment.py | 224 ++++++++++++++++++ .../wav2vec2/run_alignment.sh | 8 + 3 files changed, 263 insertions(+) create mode 100644 examples/research_projects/wav2vec2/alignment.py create mode 100644 examples/research_projects/wav2vec2/run_alignment.sh diff --git a/examples/research_projects/wav2vec2/README.md b/examples/research_projects/wav2vec2/README.md index 8f9da274f0..1dcd8dcc28 100644 --- a/examples/research_projects/wav2vec2/README.md +++ b/examples/research_projects/wav2vec2/README.md @@ -216,3 +216,34 @@ PYTHONPATH=../../../src deepspeed --num_gpus 4 run_pretrain.py \ --fp16 \ --deepspeed ds_config_wav2vec2_zero2.json \ ``` + + +### Forced Alignment + +Character level forced alignment for audio and text pairs with wav2vec2 models finetuned on ASR task for a specific language. +Inspired by [this](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html) Pytorch tutorial. + +#### Input Formats + + Input format in script.txt Input format in wavs directroy + 0000 sentence1 0000.wav + 0001 sentence2 0001.wav + +#### Output Format + +Output directory will contain 0000.txt and 0001.txt. Each file will have format like below + + char score start_ms end_ms + h 0.25 1440 1520 + +#### Run command + +``` +python alignment.py \ +--model_name="arijitx/wav2vec2-xls-r-300m-bengali" \ +--wav_dir="./wavs" +--text_file="script.txt" \ +--input_wavs_sr=48000 \ +--output_dir="./out_alignment" \ +--cuda +``` diff --git a/examples/research_projects/wav2vec2/alignment.py b/examples/research_projects/wav2vec2/alignment.py new file mode 100644 index 0000000000..24347a55a0 --- /dev/null +++ b/examples/research_projects/wav2vec2/alignment.py @@ -0,0 +1,224 @@ +# Parts of the code are adapted from the snippets provided in the TorchAudio Wav2Vec forced alignment tutorial. +# The full tutorial can be found here: https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html + +import argparse +import os +from dataclasses import dataclass + +import torch +import torchaudio +from tqdm import tqdm + +from transformers import AutoConfig, AutoModelForCTC, AutoProcessor + + +class Wav2Vec2Aligner: + def __init__(self, model_name, input_wavs_sr, cuda): + self.cuda = cuda + self.config = AutoConfig.from_pretrained(model_name) + self.model = AutoModelForCTC.from_pretrained(model_name) + self.model.eval() + if self.cuda: + self.model.to(device="cuda") + self.processor = AutoProcessor.from_pretrained(model_name) + self.resampler = torchaudio.transforms.Resample(input_wavs_sr, 16_000) + blank_id = 0 + vocab = list(self.processor.tokenizer.get_vocab().keys()) + for i in range(len(vocab)): + if vocab[i] == "[PAD]" or vocab[i] == "": + blank_id = i + print("Blank Token id [PAD]/", blank_id) + self.blank_id = blank_id + + def speech_file_to_array_fn(self, wav_path): + speech_array, sampling_rate = torchaudio.load(wav_path) + speech = self.resampler(speech_array).squeeze().numpy() + return speech + + def align_single_sample(self, item): + blank_id = self.blank_id + transcript = "|".join(item["sent"].split(" ")) + if not os.path.isfile(item["wav_path"]): + print(item["wav_path"], "not found in wavs directory") + + speech_array = self.speech_file_to_array_fn(item["wav_path"]) + inputs = self.processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True) + if self.cuda: + inputs = inputs.to(device="cuda") + + with torch.no_grad(): + logits = self.model(inputs.input_values).logits + + # get the emission probability at frame level + emissions = torch.log_softmax(logits, dim=-1) + emission = emissions[0].cpu().detach() + + # get labels from vocab + labels = ([""] + list(self.processor.tokenizer.get_vocab().keys()))[ + :-1 + ] # logits don't align with the tokenizer's vocab + + dictionary = {c: i for i, c in enumerate(labels)} + tokens = [] + for c in transcript: + if c in dictionary: + tokens.append(dictionary[c]) + + def get_trellis(emission, tokens, blank_id=0): + """ + Build a trellis matrix of shape (num_frames + 1, num_tokens + 1) + that represents the probabilities of each source token being at a certain time step + """ + num_frames = emission.size(0) + num_tokens = len(tokens) + + # Trellis has extra diemsions for both time axis and tokens. + # The extra dim for tokens represents (start-of-sentence) + # The extra dim for time axis is for simplification of the code. + trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf")) + trellis[:, 0] = 0 + for t in range(num_frames): + trellis[t + 1, 1:] = torch.maximum( + # Score for staying at the same token + trellis[t, 1:] + emission[t, blank_id], + # Score for changing to the next token + trellis[t, :-1] + emission[t, tokens], + ) + return trellis + + trellis = get_trellis(emission, tokens, blank_id) + + @dataclass + class Point: + token_index: int + time_index: int + score: float + + def backtrack(trellis, emission, tokens, blank_id=0): + """ + Walk backwards from the last (sentence_token, time_step) pair to build the optimal sequence alignment path + """ + # Note: + # j and t are indices for trellis, which has extra dimensions + # for time and tokens at the beginning. + # When referring to time frame index `T` in trellis, + # the corresponding index in emission is `T-1`. + # Similarly, when referring to token index `J` in trellis, + # the corresponding index in transcript is `J-1`. + j = trellis.size(1) - 1 + t_start = torch.argmax(trellis[:, j]).item() + + path = [] + for t in range(t_start, 0, -1): + # 1. Figure out if the current position was stay or change + # Note (again): + # `emission[J-1]` is the emission at time frame `J` of trellis dimension. + # Score for token staying the same from time frame J-1 to T. + stayed = trellis[t - 1, j] + emission[t - 1, blank_id] + # Score for token changing from C-1 at T-1 to J at T. + changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] + + # 2. Store the path with frame-wise probability. + prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() + # Return token index and time index in non-trellis coordinate. + path.append(Point(j - 1, t - 1, prob)) + + # 3. Update the token + if changed > stayed: + j -= 1 + if j == 0: + break + else: + raise ValueError("Failed to align") + return path[::-1] + + path = backtrack(trellis, emission, tokens, blank_id) + + @dataclass + class Segment: + label: str + start: int + end: int + score: float + + def __repr__(self): + return f"{self.label}\t{self.score:4.2f}\t{self.start*20:5d}\t{self.end*20:5d}" + + @property + def length(self): + return self.end - self.start + + def merge_repeats(path): + """ + Merge repeated tokens into a single segment. Note: this shouldn't affect repeated characters from the + original sentences (e.g. `ll` in `hello`) + """ + i1, i2 = 0, 0 + segments = [] + while i1 < len(path): + while i2 < len(path) and path[i1].token_index == path[i2].token_index: + i2 += 1 + score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) + segments.append( + Segment( + transcript[path[i1].token_index], + path[i1].time_index, + path[i2 - 1].time_index + 1, + score, + ) + ) + i1 = i2 + return segments + + segments = merge_repeats(path) + with open(item["out_path"], "w") as out_align: + for seg in segments: + out_align.write(str(seg) + "\n") + + def align_data(self, wav_dir, text_file, output_dir): + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # load text file + lines = open(text_file, encoding="utf8").readlines() + + items = [] + for line in lines: + if len(line.strip().split("\t")) != 2: + print("Script must be in format: 00001 this is my sentence") + exit() + + wav_name, sentence = line.strip().split("\t") + wav_path = os.path.join(wav_dir, wav_name + ".wav") + out_path = os.path.join(output_dir, wav_name + ".txt") + + items.append({"sent": sentence, "wav_path": wav_path, "out_path": out_path}) + print("Number of samples found in script file", len(items)) + + for item in tqdm(items): + self.align_single_sample(item) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name", type=str, default="arijitx/wav2vec2-xls-r-300m-bengali", help="wav2vec model name" + ) + parser.add_argument("--wav_dir", type=str, default="./wavs", help="directory containing wavs") + parser.add_argument("--text_file", type=str, default="script.txt", help="file containing text") + parser.add_argument("--input_wavs_sr", type=int, default=16000, help="sampling rate of input audios") + parser.add_argument( + "--output_dir", type=str, default="./out_alignment", help="output directory containing the alignment files" + ) + parser.add_argument("--cuda", action="store_true") + + args = parser.parse_args() + + aligner = Wav2Vec2Aligner(args.model_name, args.input_wavs_sr, args.cuda) + aligner.align_data(args.wav_dir, args.text_file, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/wav2vec2/run_alignment.sh b/examples/research_projects/wav2vec2/run_alignment.sh new file mode 100644 index 0000000000..95bfe02cf0 --- /dev/null +++ b/examples/research_projects/wav2vec2/run_alignment.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +python alignment.py \ +--model_name="arijitx/wav2vec2-xls-r-300m-bengali" \ +--wav_dir="./wavs" \ +--text_file="script.txt" \ +--input_wavs_sr=48000 \ +--output_dir="./out_alignment" \ +--cuda