[DocTests Speech] Add doc tests for all speech models (#15031)
* fix_torch_device_generate_test * remove @ * doc tests * up * up * fix doctests * adapt files * finish refactor * up * save intermediate * add more logic * new change * improve * next try * next try * next try * next try * fix final spaces * fix final spaces * improve * renaming * correct more bugs * finish wavlm * add comment * run on test runner * finish all speech models * adapt * finish
This commit is contained in:
committed by
GitHub
parent
4df69506a8
commit
9f831bdeaf
14
.github/workflows/doctests.yml
vendored
14
.github/workflows/doctests.yml
vendored
@@ -19,7 +19,7 @@ env:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_doctests:
|
run_doctests:
|
||||||
runs-on: [self-hosted, docker-gpu, single-gpu]
|
runs-on: [self-hosted, docker-gpu-test, single-gpu]
|
||||||
container:
|
container:
|
||||||
image: pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime
|
image: pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime
|
||||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||||
@@ -35,8 +35,16 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
apt -y update && apt install -y libsndfile1-dev
|
apt -y update && apt install -y libsndfile1-dev
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install .[dev]
|
pip install .[testing,torch-speech]
|
||||||
|
|
||||||
|
- name: Prepare files for doctests
|
||||||
|
run: |
|
||||||
|
python utils/prepare_for_doc_test.py src docs
|
||||||
|
|
||||||
- name: Run doctests
|
- name: Run doctests
|
||||||
run: |
|
run: |
|
||||||
pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure
|
pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"
|
||||||
|
|
||||||
|
- name: Clean files after doctests
|
||||||
|
run: |
|
||||||
|
python utils/prepare_for_doc_test.py src docs --remove_new_line
|
||||||
|
|||||||
@@ -1127,9 +1127,11 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r"""
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import {processor_class}, {model_class}
|
>>> from transformers import {processor_class}, {model_class}
|
||||||
|
>>> import torch
|
||||||
>>> from datasets import load_dataset
|
>>> from datasets import load_dataset
|
||||||
|
|
||||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||||
|
>>> dataset = dataset.sort("id")
|
||||||
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||||
|
|
||||||
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
||||||
@@ -1137,9 +1139,12 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r"""
|
|||||||
|
|
||||||
>>> # audio file is decoded on the fly
|
>>> # audio file is decoded on the fly
|
||||||
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> with torch.no_grad():
|
||||||
|
... outputs = model(**inputs)
|
||||||
|
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
>>> list(last_hidden_states.shape)
|
||||||
|
{expected_output}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1152,6 +1157,7 @@ PT_SPEECH_CTC_SAMPLE = r"""
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||||
|
>>> dataset = dataset.sort("id")
|
||||||
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||||
|
|
||||||
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
||||||
@@ -1159,17 +1165,24 @@ PT_SPEECH_CTC_SAMPLE = r"""
|
|||||||
|
|
||||||
>>> # audio file is decoded on the fly
|
>>> # audio file is decoded on the fly
|
||||||
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
||||||
>>> logits = model(**inputs).logits
|
>>> with torch.no_grad():
|
||||||
|
... logits = model(**inputs).logits
|
||||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||||
|
|
||||||
>>> # transcribe speech
|
>>> # transcribe speech
|
||||||
>>> transcription = processor.batch_decode(predicted_ids)
|
>>> transcription = processor.batch_decode(predicted_ids)
|
||||||
|
>>> transcription[0]
|
||||||
|
{expected_output}
|
||||||
|
```
|
||||||
|
|
||||||
>>> # compute loss
|
```python
|
||||||
>>> with processor.as_target_processor():
|
>>> with processor.as_target_processor():
|
||||||
... inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids
|
... inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
>>> # compute loss
|
||||||
>>> loss = model(**inputs).loss
|
>>> loss = model(**inputs).loss
|
||||||
|
>>> round(loss.item(), 2)
|
||||||
|
{expected_loss}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1182,21 +1195,31 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||||
|
>>> dataset = dataset.sort("id")
|
||||||
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||||
|
|
||||||
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
||||||
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
||||||
|
|
||||||
>>> # audio file is decoded on the fly
|
>>> # audio file is decoded on the fly
|
||||||
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
|
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
||||||
>>> logits = model(**inputs).logits
|
|
||||||
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
|
|
||||||
>>> predicted_label = model.config.id2label[predicted_class_ids]
|
|
||||||
|
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
... logits = model(**inputs).logits
|
||||||
|
|
||||||
|
>>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
|
||||||
|
>>> predicted_label = model.config.id2label[predicted_class_ids]
|
||||||
|
>>> predicted_label
|
||||||
|
{expected_output}
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
>>> # compute loss - target_label is e.g. "down"
|
>>> # compute loss - target_label is e.g. "down"
|
||||||
>>> target_label = model.config.id2label[0]
|
>>> target_label = model.config.id2label[0]
|
||||||
>>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
|
>>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
|
||||||
>>> loss = model(**inputs).loss
|
>>> loss = model(**inputs).loss
|
||||||
|
>>> round(loss.item(), 2)
|
||||||
|
{expected_loss}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1210,17 +1233,22 @@ PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||||
|
>>> dataset = dataset.sort("id")
|
||||||
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||||
|
|
||||||
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
||||||
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
||||||
|
|
||||||
>>> # audio file is decoded on the fly
|
>>> # audio file is decoded on the fly
|
||||||
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
|
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
|
||||||
>>> logits = model(**inputs).logits
|
>>> with torch.no_grad():
|
||||||
|
... logits = model(**inputs).logits
|
||||||
|
|
||||||
>>> probabilities = torch.sigmoid(logits[0])
|
>>> probabilities = torch.sigmoid(logits[0])
|
||||||
>>> # labels is a one-hot array of shape (num_frames, num_speakers)
|
>>> # labels is a one-hot array of shape (num_frames, num_speakers)
|
||||||
>>> labels = (probabilities > 0.5).long()
|
>>> labels = (probabilities > 0.5).long()
|
||||||
|
>>> labels[0].tolist()
|
||||||
|
{expected_output}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1234,14 +1262,19 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||||
|
>>> dataset = dataset.sort("id")
|
||||||
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||||
|
|
||||||
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
||||||
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
||||||
|
|
||||||
>>> # audio file is decoded on the fly
|
>>> # audio file is decoded on the fly
|
||||||
>>> inputs = feature_extractor(dataset[:2]["audio"]["array"], return_tensors="pt")
|
>>> inputs = feature_extractor(
|
||||||
>>> embeddings = model(**inputs).embeddings
|
... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
|
||||||
|
... )
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
... embeddings = model(**inputs).embeddings
|
||||||
|
|
||||||
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
|
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
|
||||||
|
|
||||||
>>> # the resulting embeddings can be used for cosine similarity-based retrieval
|
>>> # the resulting embeddings can be used for cosine similarity-based retrieval
|
||||||
@@ -1250,6 +1283,8 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
|
|||||||
>>> threshold = 0.7 # the optimal threshold is dataset-dependent
|
>>> threshold = 0.7 # the optimal threshold is dataset-dependent
|
||||||
>>> if similarity < threshold:
|
>>> if similarity < threshold:
|
||||||
... print("Speakers are not the same!")
|
... print("Speakers are not the same!")
|
||||||
|
>>> round(similarity.item(), 2)
|
||||||
|
{expected_output}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1553,9 +1588,11 @@ def add_code_sample_docstrings(
|
|||||||
checkpoint=None,
|
checkpoint=None,
|
||||||
output_type=None,
|
output_type=None,
|
||||||
config_class=None,
|
config_class=None,
|
||||||
mask=None,
|
mask="[MASK]",
|
||||||
model_cls=None,
|
model_cls=None,
|
||||||
modality=None
|
modality=None,
|
||||||
|
expected_output="",
|
||||||
|
expected_loss="",
|
||||||
):
|
):
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
# model_class defaults to function's class if not specified otherwise
|
# model_class defaults to function's class if not specified otherwise
|
||||||
@@ -1568,7 +1605,17 @@ def add_code_sample_docstrings(
|
|||||||
else:
|
else:
|
||||||
sample_docstrings = PT_SAMPLE_DOCSTRINGS
|
sample_docstrings = PT_SAMPLE_DOCSTRINGS
|
||||||
|
|
||||||
doc_kwargs = dict(model_class=model_class, processor_class=processor_class, checkpoint=checkpoint)
|
# putting all kwargs for docstrings in a dict to be used
|
||||||
|
# with the `.format(**doc_kwargs)`. Note that string might
|
||||||
|
# be formatted with non-existing keys, which is fine.
|
||||||
|
doc_kwargs = dict(
|
||||||
|
model_class=model_class,
|
||||||
|
processor_class=processor_class,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
mask=mask,
|
||||||
|
expected_output=expected_output,
|
||||||
|
expected_loss=expected_loss,
|
||||||
|
)
|
||||||
|
|
||||||
if "SequenceClassification" in model_class and modality == "audio":
|
if "SequenceClassification" in model_class and modality == "audio":
|
||||||
code_sample = sample_docstrings["AudioClassification"]
|
code_sample = sample_docstrings["AudioClassification"]
|
||||||
@@ -1581,7 +1628,6 @@ def add_code_sample_docstrings(
|
|||||||
elif "MultipleChoice" in model_class:
|
elif "MultipleChoice" in model_class:
|
||||||
code_sample = sample_docstrings["MultipleChoice"]
|
code_sample = sample_docstrings["MultipleChoice"]
|
||||||
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
|
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
|
||||||
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
|
|
||||||
code_sample = sample_docstrings["MaskedLM"]
|
code_sample = sample_docstrings["MaskedLM"]
|
||||||
elif "LMHead" in model_class or "CausalLM" in model_class:
|
elif "LMHead" in model_class or "CausalLM" in model_class:
|
||||||
code_sample = sample_docstrings["LMHead"]
|
code_sample = sample_docstrings["LMHead"]
|
||||||
|
|||||||
@@ -40,15 +40,29 @@ from .configuration_hubert import HubertConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "HubertConfig"
|
|
||||||
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
|
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
|
||||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 1
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
|
# General docstring
|
||||||
|
_CONFIG_FOR_DOC = "HubertConfig"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
|
||||||
|
|
||||||
|
# CTC docstring
|
||||||
|
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
||||||
|
_CTC_EXPECTED_LOSS = 22.68
|
||||||
|
|
||||||
|
# Audio class docstring
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 8.53
|
||||||
|
|
||||||
|
|
||||||
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"facebook/hubert-base-ls960",
|
"facebook/hubert-base-ls960",
|
||||||
@@ -1098,6 +1112,8 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=CausalLMOutput,
|
output_type=CausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_CTC_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_CTC_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1228,6 +1244,8 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -36,16 +36,33 @@ from .configuration_sew import SEWConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "SEWConfig"
|
|
||||||
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
|
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 1
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
|
|
||||||
|
# General docstring
|
||||||
|
_CONFIG_FOR_DOC = "SEWConfig"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 292, 512]
|
||||||
|
|
||||||
|
# CTC docstring
|
||||||
|
_CTC_EXPECTED_OUTPUT = (
|
||||||
|
"'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'"
|
||||||
|
)
|
||||||
|
_CTC_EXPECTED_LOSS = 0.42
|
||||||
|
|
||||||
|
# Audio class docstring
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 9.52
|
||||||
|
|
||||||
SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"asapp/sew-tiny-100k",
|
"asapp/sew-tiny-100k",
|
||||||
"asapp/sew-small-100k",
|
"asapp/sew-small-100k",
|
||||||
@@ -879,6 +896,7 @@ class SEWModel(SEWPreTrainedModel):
|
|||||||
output_type=BaseModelOutput,
|
output_type=BaseModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -978,6 +996,8 @@ class SEWForCTC(SEWPreTrainedModel):
|
|||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=CausalLMOutput,
|
output_type=CausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_CTC_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_CTC_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1108,6 +1128,8 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
|
|||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -37,15 +37,28 @@ from .configuration_sew_d import SEWDConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "SEWDConfig"
|
|
||||||
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
|
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
|
||||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 1
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
|
|
||||||
|
# General docstring
|
||||||
|
_CONFIG_FOR_DOC = "SEWDConfig"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 292, 384]
|
||||||
|
|
||||||
|
# CTC docstring
|
||||||
|
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
||||||
|
_CTC_EXPECTED_LOSS = 0.21
|
||||||
|
|
||||||
|
# Audio class docstring
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 3.16
|
||||||
|
|
||||||
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"asapp/sew-d-tiny-100k",
|
"asapp/sew-d-tiny-100k",
|
||||||
"asapp/sew-d-small-100k",
|
"asapp/sew-d-small-100k",
|
||||||
@@ -1415,6 +1428,7 @@ class SEWDModel(SEWDPreTrainedModel):
|
|||||||
output_type=BaseModelOutput,
|
output_type=BaseModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1514,6 +1528,8 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
|||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=CausalLMOutput,
|
output_type=CausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_CTC_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_CTC_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1644,6 +1660,8 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
|
|||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -42,15 +42,27 @@ from .configuration_unispeech import UniSpeechConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "UniSpeechConfig"
|
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
|
||||||
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv"
|
|
||||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 2
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
|
# General docstring
|
||||||
|
_CONFIG_FOR_DOC = "UniSpeechConfig"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
|
||||||
|
|
||||||
|
# CTC docstring
|
||||||
|
_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
|
||||||
|
_CTC_EXPECTED_LOSS = 17.17
|
||||||
|
|
||||||
|
# Audio class docstring
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" # TODO(anton) - could you quickly fine-tune a KS WavLM Model
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 0.66 # TODO(anton) - could you quickly fine-tune a KS WavLM Model
|
||||||
|
|
||||||
UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"microsoft/unispeech-large-1500h-cv",
|
"microsoft/unispeech-large-1500h-cv",
|
||||||
"microsoft/unispeech-large-multi-lingual-1500h-cv",
|
"microsoft/unispeech-large-multi-lingual-1500h-cv",
|
||||||
@@ -1129,6 +1141,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
|
|||||||
output_type=UniSpeechBaseModelOutput,
|
output_type=UniSpeechBaseModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1266,44 +1279,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
>>> import torch
|
>>> import torch
|
||||||
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
|
>>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechForPreTraining
|
||||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
>>> from transformers.models.unispeech.modeling_unispeech import _compute_mask_indices
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
|
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||||
>>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
... "hf-internal-testing/tiny-random-unispeech-sat"
|
||||||
|
... )
|
||||||
|
>>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
|
||||||
>>> def map_to_array(batch):
|
>>> # TODO: Add full pretraining example
|
||||||
... speech, _ = sf.read(batch["file"])
|
|
||||||
... batch["speech"] = speech
|
|
||||||
... return batch
|
|
||||||
|
|
||||||
|
|
||||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
|
||||||
|
|
||||||
>>> # compute masked indices
|
|
||||||
>>> batch_size, raw_sequence_length = input_values.shape
|
|
||||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
|
||||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
|
||||||
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
|
|
||||||
|
|
||||||
>>> with torch.no_grad():
|
|
||||||
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
|
||||||
|
|
||||||
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
|
||||||
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
|
||||||
|
|
||||||
>>> # show that cosine similarity is much higher than random
|
|
||||||
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
|
|
||||||
|
|
||||||
>>> # for contrastive loss training model should be put into train mode
|
|
||||||
>>> model.train()
|
|
||||||
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
|
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1406,6 +1389,8 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
|
|||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=CausalLMOutput,
|
output_type=CausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_CTC_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_CTC_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1536,6 +1521,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
|
|||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -43,16 +43,33 @@ from .configuration_unispeech_sat import UniSpeechSatConfig
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
|
# General docstring
|
||||||
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
|
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus"
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
|
||||||
|
|
||||||
|
# CTC docstring
|
||||||
|
_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
||||||
|
_CTC_EXPECTED_LOSS = 39.88
|
||||||
|
|
||||||
|
# Audio class docstring
|
||||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech-sat"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" # TODO(anton) - could you quickly fine-tune a KS WavLM Model
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 0.71 # TODO(anton) - could you quickly fine-tune a KS WavLM Model
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus"
|
# Frame class docstring
|
||||||
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
|
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
|
||||||
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
|
_FRAME_EXPECTED_OUTPUT = [0, 0]
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 2
|
# Speaker Verification docstring
|
||||||
|
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
|
||||||
|
_XVECTOR_EXPECTED_OUTPUT = 0.97
|
||||||
|
|
||||||
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
# See all UniSpeechSat models at https://huggingface.co/models?filter=unispeech_sat
|
# See all UniSpeechSat models at https://huggingface.co/models?filter=unispeech_sat
|
||||||
@@ -1163,6 +1180,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
|
|||||||
output_type=UniSpeechSatBaseModelOutput,
|
output_type=UniSpeechSatBaseModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1300,42 +1318,10 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
>>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForPreTraining
|
>>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForPreTraining
|
||||||
>>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices
|
>>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
|
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
|
||||||
>>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")
|
>>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")
|
||||||
|
>>> # TODO: Add full pretraining example
|
||||||
|
|
||||||
>>> def map_to_array(batch):
|
|
||||||
... speech, _ = sf.read(batch["file"])
|
|
||||||
... batch["speech"] = speech
|
|
||||||
... return batch
|
|
||||||
|
|
||||||
|
|
||||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
|
||||||
|
|
||||||
>>> # compute masked indices
|
|
||||||
>>> batch_size, raw_sequence_length = input_values.shape
|
|
||||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
|
||||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
|
||||||
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
|
|
||||||
|
|
||||||
>>> with torch.no_grad():
|
|
||||||
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
|
||||||
|
|
||||||
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
|
||||||
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
|
||||||
|
|
||||||
>>> # show that cosine similarity is much higher than random
|
|
||||||
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
|
|
||||||
|
|
||||||
>>> # for contrastive loss training model should be put into train mode
|
|
||||||
>>> model.train()
|
|
||||||
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
|
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1431,6 +1417,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
|
|||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=CausalLMOutput,
|
output_type=CausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_CTC_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_CTC_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1561,6 +1549,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
|
|||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1677,6 +1667,7 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
|
|||||||
output_type=TokenClassifierOutput,
|
output_type=TokenClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_FRAME_EXPECTED_OUTPUT,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1853,6 +1844,7 @@ class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
|
|||||||
output_type=XVectorOutput,
|
output_type=XVectorOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_XVECTOR_EXPECTED_OUTPUT,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -48,17 +48,35 @@ from .configuration_wav2vec2 import Wav2Vec2Config
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
|
||||||
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
|
||||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
|
|
||||||
_FRAME_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-sd"
|
|
||||||
_XVECTOR_CHECKPOINT = "superb/wav2vec2-base-superb-sv"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 2
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
|
# General docstring
|
||||||
|
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
|
||||||
|
|
||||||
|
# CTC docstring
|
||||||
|
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
||||||
|
_CTC_EXPECTED_LOSS = 53.48
|
||||||
|
|
||||||
|
# Audio class docstring
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 6.54
|
||||||
|
|
||||||
|
# Frame class docstring
|
||||||
|
_FRAME_CLASS_CHECKPOINT = "anton-l/wav2vec2-base-superb-sd"
|
||||||
|
_FRAME_EXPECTED_OUTPUT = [0, 0]
|
||||||
|
|
||||||
|
# Speaker Verification docstring
|
||||||
|
_XVECTOR_CHECKPOINT = "anton-l/wav2vec2-base-superb-sv"
|
||||||
|
_XVECTOR_EXPECTED_OUTPUT = 0.98
|
||||||
|
|
||||||
|
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"facebook/wav2vec2-base-960h",
|
"facebook/wav2vec2-base-960h",
|
||||||
@@ -1294,6 +1312,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
output_type=Wav2Vec2BaseModelOutput,
|
output_type=Wav2Vec2BaseModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1469,10 +1488,11 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||||||
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
||||||
|
|
||||||
>>> # show that cosine similarity is much higher than random
|
>>> # show that cosine similarity is much higher than random
|
||||||
>>> assert cosine_sim[mask_time_indices].mean() > 0.5
|
>>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
|
||||||
|
tensor(True)
|
||||||
|
|
||||||
>>> # for contrastive loss training model should be put into train mode
|
>>> # for contrastive loss training model should be put into train mode
|
||||||
>>> model.train()
|
>>> model = model.train()
|
||||||
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
|
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
@@ -1697,6 +1717,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=CausalLMOutput,
|
output_type=CausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_CTC_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_CTC_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1826,6 +1848,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
|||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1941,6 +1965,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
|
|||||||
output_type=TokenClassifierOutput,
|
output_type=TokenClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_FRAME_EXPECTED_OUTPUT,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -2114,6 +2139,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
|
|||||||
output_type=XVectorOutput,
|
output_type=XVectorOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_XVECTOR_EXPECTED_OUTPUT,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -42,19 +42,35 @@ from .configuration_wavlm import WavLMConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "WavLMConfig"
|
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
|
||||||
_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
|
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base"
|
|
||||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus"
|
|
||||||
_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
|
|
||||||
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 2
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
|
# General docstring
|
||||||
|
_CONFIG_FOR_DOC = "WavLMConfig"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
|
||||||
|
|
||||||
|
# CTC docstring
|
||||||
|
_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
|
||||||
|
_CTC_EXPECTED_LOSS = 12.51
|
||||||
|
|
||||||
|
# Audio class docstring
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-wavlm"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'no'" # TODO(anton) - could you quickly fine-tune a KS WavLM Model
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 0.7 # TODO(anton) - could you quickly fine-tune a KS WavLM Model
|
||||||
|
|
||||||
|
# Frame class docstring
|
||||||
|
_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
|
||||||
|
_FRAME_EXPECTED_OUTPUT = [0, 0]
|
||||||
|
|
||||||
|
# Speaker Verification docstring
|
||||||
|
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
|
||||||
|
_XVECTOR_EXPECTED_OUTPUT = 0.97
|
||||||
|
|
||||||
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"microsoft/wavlm-base",
|
"microsoft/wavlm-base",
|
||||||
"microsoft/wavlm-base-plus",
|
"microsoft/wavlm-base-plus",
|
||||||
@@ -1247,6 +1263,7 @@ class WavLMModel(WavLMPreTrainedModel):
|
|||||||
output_type=WavLMBaseModelOutput,
|
output_type=WavLMBaseModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1350,6 +1367,8 @@ class WavLMForCTC(WavLMPreTrainedModel):
|
|||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
output_type=CausalLMOutput,
|
output_type=CausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_CTC_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_CTC_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1480,6 +1499,8 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
|
|||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||||
|
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1596,6 +1617,7 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
|
|||||||
output_type=TokenClassifierOutput,
|
output_type=TokenClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_FRAME_EXPECTED_OUTPUT,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1772,6 +1794,7 @@ class WavLMForXVector(WavLMPreTrainedModel):
|
|||||||
output_type=XVectorOutput,
|
output_type=XVectorOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
|
expected_output=_XVECTOR_EXPECTED_OUTPUT,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,2 +1,7 @@
|
|||||||
docs/source/quicktour.rst
|
src/transformers/models/wav2vec2/modeling_wav2vec2.py
|
||||||
docs/source/task_summary.rst
|
src/transformers/models/hubert/modeling_hubert.py
|
||||||
|
src/transformers/models/wavlm/modeling_wavlm.py
|
||||||
|
src/transformers/models/unispeech/modeling_unispeech.py
|
||||||
|
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
|
||||||
|
src/transformers/models/sew/modeling_sew.py
|
||||||
|
src/transformers/models/sew_d/modeling_sew_d.py
|
||||||
|
|||||||
145
utils/prepare_for_doc_test.py
Normal file
145
utils/prepare_for_doc_test.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Style utils to preprocess files for doc tests.
|
||||||
|
|
||||||
|
The doc precossing function can be run on a list of files and/org
|
||||||
|
directories of files. It will recursively check if the files have
|
||||||
|
a python code snippet by looking for a ```python or ```py syntax.
|
||||||
|
In the default mode - `remove_new_line==False` the script will
|
||||||
|
add a new line before every python code ending ``` line to make
|
||||||
|
the docstrings ready for pytest doctests.
|
||||||
|
However, we don't want to have empty lines displayed in the
|
||||||
|
official documentation which is why the new line command can be
|
||||||
|
reversed by adding the flag `--remove_new_line` which sets
|
||||||
|
`remove_new_line==True`.
|
||||||
|
|
||||||
|
When debugging the doc tests locally, please make sure to
|
||||||
|
always run:
|
||||||
|
|
||||||
|
```python utils/prepare_for_doc_test.py src doc```
|
||||||
|
|
||||||
|
before running the doc tests:
|
||||||
|
|
||||||
|
```pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"```
|
||||||
|
|
||||||
|
Afterwards you should revert the changes by running
|
||||||
|
|
||||||
|
```python utils/prepare_for_doc_test.py src doc --remove_new_line```
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def process_code_block(code, add_new_line=True):
|
||||||
|
if add_new_line:
|
||||||
|
return maybe_append_new_line(code)
|
||||||
|
else:
|
||||||
|
return maybe_remove_new_line(code)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_append_new_line(code):
|
||||||
|
"""
|
||||||
|
Append new line if code snippet is a
|
||||||
|
Python code snippet
|
||||||
|
"""
|
||||||
|
lines = code.split("\n")
|
||||||
|
|
||||||
|
if lines[0] in ["py", "python"]:
|
||||||
|
# add new line before last line being ```
|
||||||
|
last_line = lines[-1]
|
||||||
|
lines.pop()
|
||||||
|
lines.append("\n" + last_line)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_remove_new_line(code):
|
||||||
|
"""
|
||||||
|
Remove new line if code snippet is a
|
||||||
|
Python code snippet
|
||||||
|
"""
|
||||||
|
lines = code.split("\n")
|
||||||
|
|
||||||
|
if lines[0] in ["py", "python"]:
|
||||||
|
# add new line before last line being ```
|
||||||
|
lines = lines[:-2] + lines[-1:]
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def process_doc_file(code_file, add_new_line=True):
|
||||||
|
"""
|
||||||
|
Process given file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
|
||||||
|
"""
|
||||||
|
with open(code_file, "r", encoding="utf-8", newline="\n") as f:
|
||||||
|
code = f.read()
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
splits = code.split("```")
|
||||||
|
splits = [s if i % 2 == 0 else process_code_block(s, add_new_line=add_new_line) for i, s in enumerate(splits)]
|
||||||
|
clean_code = "```".join(splits)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
diff = clean_code != code
|
||||||
|
if diff:
|
||||||
|
print(f"Overwriting content of {code_file}.")
|
||||||
|
with open(code_file, "w", encoding="utf-8", newline="\n") as f:
|
||||||
|
f.write(clean_code)
|
||||||
|
|
||||||
|
|
||||||
|
def process_doc_files(*files, add_new_line=True):
|
||||||
|
"""
|
||||||
|
Applies doc styling or checks everything is correct in a list of files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files (several `str` or `os.PathLike`): The files to treat.
|
||||||
|
Whether to restyle file or just check if they should be restyled.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[`str`]: The list of files changed or that should be restyled.
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
# Treat folders
|
||||||
|
if os.path.isdir(file):
|
||||||
|
files = [os.path.join(file, f) for f in os.listdir(file)]
|
||||||
|
files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
|
||||||
|
process_doc_files(*files, add_new_line=add_new_line)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
process_doc_file(file, add_new_line=add_new_line)
|
||||||
|
except Exception:
|
||||||
|
print(f"There is a problem in {file}.")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def main(*files, add_new_line=True):
|
||||||
|
process_doc_files(*files, add_new_line=add_new_line)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--remove_new_line",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to remove new line after each python code block instead of adding one.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(*args.files, add_new_line=not args.remove_new_line)
|
||||||
Reference in New Issue
Block a user