Add more missing models to models/__init__.py (#14177)

* Add missing models to models/__init__.py

* Fix issues previously undetected

* Add UniSpeechSatForPreTraining to all_model_classes

* fix unispeech sat

* fix

* Add check_model_list() to check_repo.py

* Remove _ignore_models = ["bort"]

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Yih-Dar
2021-11-01 11:52:36 +01:00
committed by GitHub
parent 9fc1951711
commit 9450bfcc6c
4 changed files with 57 additions and 14 deletions

View File

@@ -34,6 +34,7 @@ if is_torch_available():
from transformers import (
UniSpeechSatForCTC,
UniSpeechSatForPreTraining,
UniSpeechSatForSequenceClassification,
UniSpeechSatModel,
Wav2Vec2FeatureExtractor,
@@ -293,7 +294,9 @@ class UniSpeechSatModelTester:
@require_torch
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(UniSpeechSatForCTC, UniSpeechSatModel, UniSpeechSatForSequenceClassification) if is_torch_available() else ()
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
@@ -407,6 +410,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"label_embeddings_concat",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
@@ -490,7 +494,9 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(UniSpeechSatForCTC, UniSpeechSatModel, UniSpeechSatForSequenceClassification) if is_torch_available() else ()
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
@@ -610,6 +616,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"label_embeddings_concat",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):