Improve mismatched sizes management when loading a pretrained model (#17257)

- Add --ignore_mismatched_sizes argument to classification examples

- Expand the error message when loading a model whose head dimensions are different from expected dimensions
This commit is contained in:
regisss
2022-05-17 17:58:14 +02:00
committed by GitHub
parent 1f13ba818e
commit 28a0811652
13 changed files with 64 additions and 9 deletions

View File

@@ -18,13 +18,13 @@ limitations under the License.
The following examples showcase how to fine-tune `Wav2Vec2` for audio classification using PyTorch.
Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
very little annotated data to yield good performance on speech classification datasets.
## Single-GPU
## Single-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the 🗣️ [Keyword Spotting subset](https://huggingface.co/datasets/superb#ks) of the SUPERB dataset.
@@ -63,7 +63,9 @@ On a single V100 GPU (16GB), this script should run in ~14 minutes and yield acc
👀 See the results here: [anton-l/wav2vec2-base-ft-keyword-spotting](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting)
## Multi-GPU
> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
## Multi-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) for 🌎 **Language Identification** on the [CommonLanguage dataset](https://huggingface.co/datasets/anton-l/common_language).
@@ -139,7 +141,7 @@ It has been verified that the script works for the following datasets:
| Dataset | Pretrained Model | # transformer layers | Accuracy on eval | GPU setup | Training time | Fine-tuned Model & Logs |
|---------|------------------|----------------------|------------------|-----------|---------------|--------------------------|
| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
| Keyword Spotting | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) | 12 | 0.9826 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting) |
| Keyword Spotting | [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) | 12 | 0.9819 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/hubert-base-ft-keyword-spotting) |
| Keyword Spotting | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 24 | 0.9757 | 1 V100 GPU | 15min | [here](https://huggingface.co/anton-l/sew-mid-100k-ft-keyword-spotting) |

View File

@@ -163,6 +163,10 @@ class ModelArguments:
freeze_feature_extractor: Optional[bool] = field(
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
ignore_mismatched_sizes: bool = field(
default=False,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
def __post_init__(self):
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
@@ -333,6 +337,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# freeze the convolutional waveform encoder