Add more missing models to models/__init__.py (#14177)
* Add missing models to models/__init__.py * Fix issues previously undetected * Add UniSpeechSatForPreTraining to all_model_classes * fix unispeech sat * fix * Add check_model_list() to check_repo.py * Remove _ignore_models = ["bort"] Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from . import (
|
||||
bigbird_pegasus,
|
||||
blenderbot,
|
||||
blenderbot_small,
|
||||
byt5,
|
||||
camembert,
|
||||
canine,
|
||||
clip,
|
||||
@@ -67,6 +68,7 @@ from . import (
|
||||
mbart,
|
||||
mbart50,
|
||||
megatron_bert,
|
||||
megatron_gpt2,
|
||||
mmbt,
|
||||
mobilebert,
|
||||
mpnet,
|
||||
@@ -84,12 +86,17 @@ from . import (
|
||||
segformer,
|
||||
sew,
|
||||
sew_d,
|
||||
speech_encoder_decoder,
|
||||
speech_to_text,
|
||||
speech_to_text_2,
|
||||
splinter,
|
||||
squeezebert,
|
||||
t5,
|
||||
tapas,
|
||||
transfo_xl,
|
||||
trocr,
|
||||
unispeech,
|
||||
unispeech_sat,
|
||||
vision_encoder_decoder,
|
||||
visual_bert,
|
||||
vit,
|
||||
|
||||
@@ -1157,6 +1157,7 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
|
||||
|
||||
self.speaker_proj = nn.Linear(config.hidden_size, config.codevector_dim)
|
||||
self.label_embeddings_concat = nn.Parameter(torch.FloatTensor(config.num_clusters, config.codevector_dim))
|
||||
self.label_embeddings_concat.data.zero_()
|
||||
|
||||
self.layer_norm_for_extract = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
if self.config.do_stable_layer_norm:
|
||||
@@ -1268,21 +1269,24 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
|
||||
# quantize all (unmasked) extracted features and project to final vq dim
|
||||
extract_features = self.dropout_features(outputs[1])
|
||||
|
||||
# layer normalization (has no effect when `config.do_stable_layer_norm == False`)
|
||||
extract_features = self.layer_norm_for_extract(extract_features)
|
||||
quantized_features, codevector_perplexity = self.quantizer(extract_features)
|
||||
|
||||
# project quantized features twice
|
||||
quantized_features = self.project_q(quantized_features)
|
||||
quantized_features = self.project_hid(quantized_features)
|
||||
|
||||
# TODO(PVP) - add pretraining logic and add to tests
|
||||
loss = None
|
||||
logits = None
|
||||
logits = extract_features
|
||||
loss = quantized_features = codevector_perplexity = None
|
||||
|
||||
# layer normalization (has no effect when `config.do_stable_layer_norm == False`)
|
||||
# extract_features = self.layer_norm_for_extract(extract_features)
|
||||
# quantized_features, codevector_perplexity = self.quantizer(extract_features)
|
||||
#
|
||||
# project quantized features twice
|
||||
# quantized_features = self.project_q(quantized_features)
|
||||
# quantized_features = self.project_hid(quantized_features)
|
||||
#
|
||||
# loss = None
|
||||
# logits = quantized_features
|
||||
if not return_dict:
|
||||
if loss is not None:
|
||||
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
||||
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
||||
return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
||||
return (logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
||||
|
||||
return UniSpeechSatForPreTrainingOutput(
|
||||
loss=loss,
|
||||
|
||||
@@ -34,6 +34,7 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
@@ -293,7 +294,9 @@ class UniSpeechSatModelTester:
|
||||
@require_torch
|
||||
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(UniSpeechSatForCTC, UniSpeechSatModel, UniSpeechSatForSequenceClassification) if is_torch_available() else ()
|
||||
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -407,6 +410,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -490,7 +494,9 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(UniSpeechSatForCTC, UniSpeechSatModel, UniSpeechSatForSequenceClassification) if is_torch_available() else ()
|
||||
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -610,6 +616,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
|
||||
@@ -73,9 +73,11 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
"DPREncoder", # Building part of bigger (tested) model.
|
||||
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
||||
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
|
||||
"TFDPREncoder", # Building part of bigger (tested) model.
|
||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||
"TFRobertaForMultipleChoice", # TODO: fix
|
||||
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||
]
|
||||
|
||||
@@ -148,6 +150,26 @@ spec = importlib.util.spec_from_file_location(
|
||||
transformers = spec.loader.load_module()
|
||||
|
||||
|
||||
def check_model_list():
|
||||
"""Check the model list inside the transformers library."""
|
||||
# Get the models from the directory structure of `src/transformers/models/`
|
||||
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
|
||||
_models = []
|
||||
for model in os.listdir(models_dir):
|
||||
model_dir = os.path.join(models_dir, model)
|
||||
if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
|
||||
_models.append(model)
|
||||
|
||||
# Get the models from the directory structure of `src/transformers/models/`
|
||||
models = [model for model in dir(transformers.models) if not model.startswith("__")]
|
||||
|
||||
missing_models = sorted(list(set(_models).difference(models)))
|
||||
if missing_models:
|
||||
raise Exception(
|
||||
f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}."
|
||||
)
|
||||
|
||||
|
||||
# If some modeling modules should be ignored for all checks, they should be added in the nested list
|
||||
# _ignore_modules of this function.
|
||||
def get_model_modules():
|
||||
@@ -163,6 +185,7 @@ def get_model_modules():
|
||||
"modeling_flax_auto",
|
||||
"modeling_flax_encoder_decoder",
|
||||
"modeling_flax_utils",
|
||||
"modeling_speech_encoder_decoder",
|
||||
"modeling_transfo_xl_utilities",
|
||||
"modeling_tf_auto",
|
||||
"modeling_tf_encoder_decoder",
|
||||
@@ -560,6 +583,8 @@ def check_all_objects_are_documented():
|
||||
|
||||
def check_repo_quality():
|
||||
"""Check all models are properly tested and documented."""
|
||||
print("Checking all models are included.")
|
||||
check_model_list()
|
||||
print("Checking all models are public.")
|
||||
check_models_are_in_init()
|
||||
print("Checking all models are properly tested.")
|
||||
|
||||
Reference in New Issue
Block a user