Add more missing models to models/__init__.py (#14177)
* Add missing models to models/__init__.py * Fix issues previously undetected * Add UniSpeechSatForPreTraining to all_model_classes * fix unispeech sat * fix * Add check_model_list() to check_repo.py * Remove _ignore_models = ["bort"] Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -34,6 +34,7 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
@@ -293,7 +294,9 @@ class UniSpeechSatModelTester:
|
||||
@require_torch
|
||||
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(UniSpeechSatForCTC, UniSpeechSatModel, UniSpeechSatForSequenceClassification) if is_torch_available() else ()
|
||||
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -407,6 +410,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -490,7 +494,9 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(UniSpeechSatForCTC, UniSpeechSatModel, UniSpeechSatForSequenceClassification) if is_torch_available() else ()
|
||||
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -610,6 +616,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
|
||||
Reference in New Issue
Block a user