update_pip_test_mapping (#22606)
* Add TFBlipForConditionalGeneration * update pipeline_model_mapping * Add import * Revert changes in GPTSanJapaneseTest --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -231,6 +231,7 @@ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
|||||||
|
|
||||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
|
("blip", "TFBlipForConditionalGeneration"),
|
||||||
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
|
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from ...test_modeling_common import (
|
|||||||
ids_tensor,
|
ids_tensor,
|
||||||
random_attention_mask,
|
random_attention_mask,
|
||||||
)
|
)
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -420,8 +421,9 @@ class AlignModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class AlignModelTest(ModelTesterMixin, unittest.TestCase):
|
class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (AlignModel,) if is_torch_available() else ()
|
all_model_classes = (AlignModel,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = {"feature-extraction": AlignModel} if is_torch_available() else {}
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -429,9 +429,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
"fill-mask": BartForConditionalGeneration,
|
"fill-mask": BartForConditionalGeneration,
|
||||||
"question-answering": BartForQuestionAnswering,
|
"question-answering": BartForQuestionAnswering,
|
||||||
"summarization": BartForConditionalGeneration,
|
"summarization": BartForConditionalGeneration,
|
||||||
"text2text-generation": BartForConditionalGeneration,
|
|
||||||
"text-classification": BartForSequenceClassification,
|
"text-classification": BartForSequenceClassification,
|
||||||
"text-generation": BartForCausalLM,
|
"text-generation": BartForCausalLM,
|
||||||
|
"text2text-generation": BartForConditionalGeneration,
|
||||||
|
"translation": BartForConditionalGeneration,
|
||||||
"zero-shot": BartForSequenceClassification,
|
"zero-shot": BartForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
|||||||
@@ -199,8 +199,9 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||||||
"conversational": TFBartForConditionalGeneration,
|
"conversational": TFBartForConditionalGeneration,
|
||||||
"feature-extraction": TFBartModel,
|
"feature-extraction": TFBartModel,
|
||||||
"summarization": TFBartForConditionalGeneration,
|
"summarization": TFBartForConditionalGeneration,
|
||||||
"text2text-generation": TFBartForConditionalGeneration,
|
|
||||||
"text-classification": TFBartForSequenceClassification,
|
"text-classification": TFBartForSequenceClassification,
|
||||||
|
"text2text-generation": TFBartForConditionalGeneration,
|
||||||
|
"translation": TFBartForConditionalGeneration,
|
||||||
"zero-shot": TFBartForSequenceClassification,
|
"zero-shot": TFBartForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
|
|||||||
@@ -251,9 +251,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
"feature-extraction": BigBirdPegasusModel,
|
"feature-extraction": BigBirdPegasusModel,
|
||||||
"question-answering": BigBirdPegasusForQuestionAnswering,
|
"question-answering": BigBirdPegasusForQuestionAnswering,
|
||||||
"summarization": BigBirdPegasusForConditionalGeneration,
|
"summarization": BigBirdPegasusForConditionalGeneration,
|
||||||
"text2text-generation": BigBirdPegasusForConditionalGeneration,
|
|
||||||
"text-classification": BigBirdPegasusForSequenceClassification,
|
"text-classification": BigBirdPegasusForSequenceClassification,
|
||||||
"text-generation": BigBirdPegasusForCausalLM,
|
"text-generation": BigBirdPegasusForCausalLM,
|
||||||
|
"text2text-generation": BigBirdPegasusForConditionalGeneration,
|
||||||
|
"translation": BigBirdPegasusForConditionalGeneration,
|
||||||
"zero-shot": BigBirdPegasusForSequenceClassification,
|
"zero-shot": BigBirdPegasusForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
|||||||
@@ -232,8 +232,9 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
"conversational": BlenderbotForConditionalGeneration,
|
"conversational": BlenderbotForConditionalGeneration,
|
||||||
"feature-extraction": BlenderbotModel,
|
"feature-extraction": BlenderbotModel,
|
||||||
"summarization": BlenderbotForConditionalGeneration,
|
"summarization": BlenderbotForConditionalGeneration,
|
||||||
"text2text-generation": BlenderbotForConditionalGeneration,
|
|
||||||
"text-generation": BlenderbotForCausalLM,
|
"text-generation": BlenderbotForCausalLM,
|
||||||
|
"text2text-generation": BlenderbotForConditionalGeneration,
|
||||||
|
"translation": BlenderbotForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
|
|||||||
"feature-extraction": TFBlenderbotModel,
|
"feature-extraction": TFBlenderbotModel,
|
||||||
"summarization": TFBlenderbotForConditionalGeneration,
|
"summarization": TFBlenderbotForConditionalGeneration,
|
||||||
"text2text-generation": TFBlenderbotForConditionalGeneration,
|
"text2text-generation": TFBlenderbotForConditionalGeneration,
|
||||||
|
"translation": TFBlenderbotForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -226,8 +226,9 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
|||||||
"conversational": BlenderbotSmallForConditionalGeneration,
|
"conversational": BlenderbotSmallForConditionalGeneration,
|
||||||
"feature-extraction": BlenderbotSmallModel,
|
"feature-extraction": BlenderbotSmallModel,
|
||||||
"summarization": BlenderbotSmallForConditionalGeneration,
|
"summarization": BlenderbotSmallForConditionalGeneration,
|
||||||
"text2text-generation": BlenderbotSmallForConditionalGeneration,
|
|
||||||
"text-generation": BlenderbotSmallForCausalLM,
|
"text-generation": BlenderbotSmallForCausalLM,
|
||||||
|
"text2text-generation": BlenderbotSmallForConditionalGeneration,
|
||||||
|
"translation": BlenderbotSmallForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
|
|||||||
"feature-extraction": TFBlenderbotSmallModel,
|
"feature-extraction": TFBlenderbotSmallModel,
|
||||||
"summarization": TFBlenderbotSmallForConditionalGeneration,
|
"summarization": TFBlenderbotSmallForConditionalGeneration,
|
||||||
"text2text-generation": TFBlenderbotSmallForConditionalGeneration,
|
"text2text-generation": TFBlenderbotSmallForConditionalGeneration,
|
||||||
|
"translation": TFBlenderbotSmallForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
|
|||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -164,7 +165,7 @@ class ConvNextV2ModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ConvNextV2ModelTest(ModelTesterMixin, unittest.TestCase):
|
class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Here we also overwrite some of the tests of test_modeling_common.py, as ConvNextV2 does not use input_ids, inputs_embeds,
|
Here we also overwrite some of the tests of test_modeling_common.py, as ConvNextV2 does not use input_ids, inputs_embeds,
|
||||||
attention_mask and seq_length.
|
attention_mask and seq_length.
|
||||||
@@ -179,6 +180,11 @@ class ConvNextV2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{"feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av
|
|||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -122,13 +123,18 @@ class EfficientNetModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class EfficientNetModelTest(ModelTesterMixin, unittest.TestCase):
|
class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Here we also overwrite some of the tests of test_modeling_common.py, as EfficientNet does not use input_ids, inputs_embeds,
|
Here we also overwrite some of the tests of test_modeling_common.py, as EfficientNet does not use input_ids, inputs_embeds,
|
||||||
attention_mask and seq_length.
|
attention_mask and seq_length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
|
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -163,6 +163,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
"feature-extraction": FSMTModel,
|
"feature-extraction": FSMTModel,
|
||||||
"summarization": FSMTForConditionalGeneration,
|
"summarization": FSMTForConditionalGeneration,
|
||||||
"text2text-generation": FSMTForConditionalGeneration,
|
"text2text-generation": FSMTForConditionalGeneration,
|
||||||
|
"translation": FSMTForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class GPTSanJapaneseTester:
|
|||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return GPTSanJapaneseConfig(
|
return GPTSanJapaneseConfig(
|
||||||
vocab_size=36000,
|
vocab_size=self.vocab_size,
|
||||||
num_contexts=self.seq_length,
|
num_contexts=self.seq_length,
|
||||||
d_model=self.hidden_size,
|
d_model=self.hidden_size,
|
||||||
d_ff=self.d_ff,
|
d_ff=self.d_ff,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from transformers.testing_utils import is_flaky, require_torch, slow, torch_devi
|
|||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
@@ -177,9 +178,10 @@ class InformerModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class InformerModelTest(ModelTesterMixin, unittest.TestCase):
|
class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (InformerModel, InformerForPrediction) if is_torch_available() else ()
|
all_model_classes = (InformerModel, InformerForPrediction) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else ()
|
all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = {"feature-extraction": InformerModel} if is_torch_available() else {}
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|||||||
@@ -282,8 +282,9 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
"feature-extraction": LEDModel,
|
"feature-extraction": LEDModel,
|
||||||
"question-answering": LEDForQuestionAnswering,
|
"question-answering": LEDForQuestionAnswering,
|
||||||
"summarization": LEDForConditionalGeneration,
|
"summarization": LEDForConditionalGeneration,
|
||||||
"text2text-generation": LEDForConditionalGeneration,
|
|
||||||
"text-classification": LEDForSequenceClassification,
|
"text-classification": LEDForSequenceClassification,
|
||||||
|
"text2text-generation": LEDForConditionalGeneration,
|
||||||
|
"translation": LEDForConditionalGeneration,
|
||||||
"zero-shot": LEDForSequenceClassification,
|
"zero-shot": LEDForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
|||||||
@@ -199,6 +199,7 @@ class TFLEDModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
"feature-extraction": TFLEDModel,
|
"feature-extraction": TFLEDModel,
|
||||||
"summarization": TFLEDForConditionalGeneration,
|
"summarization": TFLEDForConditionalGeneration,
|
||||||
"text2text-generation": TFLEDForConditionalGeneration,
|
"text2text-generation": TFLEDForConditionalGeneration,
|
||||||
|
"translation": TFLEDForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -510,6 +510,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
"feature-extraction": LongT5Model,
|
"feature-extraction": LongT5Model,
|
||||||
"summarization": LongT5ForConditionalGeneration,
|
"summarization": LongT5ForConditionalGeneration,
|
||||||
"text2text-generation": LongT5ForConditionalGeneration,
|
"text2text-generation": LongT5ForConditionalGeneration,
|
||||||
|
"translation": LongT5ForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -237,6 +237,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
"feature-extraction": M2M100Model,
|
"feature-extraction": M2M100Model,
|
||||||
"summarization": M2M100ForConditionalGeneration,
|
"summarization": M2M100ForConditionalGeneration,
|
||||||
"text2text-generation": M2M100ForConditionalGeneration,
|
"text2text-generation": M2M100ForConditionalGeneration,
|
||||||
|
"translation": M2M100ForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -244,8 +244,9 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
"conversational": MarianMTModel,
|
"conversational": MarianMTModel,
|
||||||
"feature-extraction": MarianModel,
|
"feature-extraction": MarianModel,
|
||||||
"summarization": MarianMTModel,
|
"summarization": MarianMTModel,
|
||||||
"text2text-generation": MarianMTModel,
|
|
||||||
"text-generation": MarianForCausalLM,
|
"text-generation": MarianForCausalLM,
|
||||||
|
"text2text-generation": MarianMTModel,
|
||||||
|
"translation": MarianMTModel,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ class TFMarianModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
"feature-extraction": TFMarianModel,
|
"feature-extraction": TFMarianModel,
|
||||||
"summarization": TFMarianMTModel,
|
"summarization": TFMarianMTModel,
|
||||||
"text2text-generation": TFMarianMTModel,
|
"text2text-generation": TFMarianMTModel,
|
||||||
|
"translation": TFMarianMTModel,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -239,9 +239,10 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
"fill-mask": MBartForConditionalGeneration,
|
"fill-mask": MBartForConditionalGeneration,
|
||||||
"question-answering": MBartForQuestionAnswering,
|
"question-answering": MBartForQuestionAnswering,
|
||||||
"summarization": MBartForConditionalGeneration,
|
"summarization": MBartForConditionalGeneration,
|
||||||
"text2text-generation": MBartForConditionalGeneration,
|
|
||||||
"text-classification": MBartForSequenceClassification,
|
"text-classification": MBartForSequenceClassification,
|
||||||
"text-generation": MBartForCausalLM,
|
"text-generation": MBartForCausalLM,
|
||||||
|
"text2text-generation": MBartForConditionalGeneration,
|
||||||
|
"translation": MBartForConditionalGeneration,
|
||||||
"zero-shot": MBartForSequenceClassification,
|
"zero-shot": MBartForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
|||||||
@@ -190,6 +190,7 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||||||
"feature-extraction": TFMBartModel,
|
"feature-extraction": TFMBartModel,
|
||||||
"summarization": TFMBartForConditionalGeneration,
|
"summarization": TFMBartForConditionalGeneration,
|
||||||
"text2text-generation": TFMBartForConditionalGeneration,
|
"text2text-generation": TFMBartForConditionalGeneration,
|
||||||
|
"translation": TFMBartForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -471,9 +471,11 @@ class MegaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": MegaModel,
|
"feature-extraction": MegaModel,
|
||||||
|
"fill-mask": MegaForMaskedLM,
|
||||||
"question-answering": MegaForQuestionAnswering,
|
"question-answering": MegaForQuestionAnswering,
|
||||||
"text-classification": MegaForSequenceClassification,
|
"text-classification": MegaForSequenceClassification,
|
||||||
"text-generation": MegaForCausalLM,
|
"text-generation": MegaForCausalLM,
|
||||||
|
"token-classification": MegaForTokenClassification,
|
||||||
"zero-shot": MegaForSequenceClassification,
|
"zero-shot": MegaForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from transformers.utils import is_torch_available, is_vision_available
|
|||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -116,8 +117,9 @@ class MgpstrModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MgpstrModelTest(ModelTesterMixin, unittest.TestCase):
|
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
|
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {}
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -420,9 +420,10 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
"fill-mask": MvpForConditionalGeneration,
|
"fill-mask": MvpForConditionalGeneration,
|
||||||
"question-answering": MvpForQuestionAnswering,
|
"question-answering": MvpForQuestionAnswering,
|
||||||
"summarization": MvpForConditionalGeneration,
|
"summarization": MvpForConditionalGeneration,
|
||||||
"text2text-generation": MvpForConditionalGeneration,
|
|
||||||
"text-classification": MvpForSequenceClassification,
|
"text-classification": MvpForSequenceClassification,
|
||||||
"text-generation": MvpForCausalLM,
|
"text-generation": MvpForCausalLM,
|
||||||
|
"text2text-generation": MvpForConditionalGeneration,
|
||||||
|
"translation": MvpForConditionalGeneration,
|
||||||
"zero-shot": MvpForSequenceClassification,
|
"zero-shot": MvpForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
|||||||
@@ -255,6 +255,7 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
"feature-extraction": NllbMoeModel,
|
"feature-extraction": NllbMoeModel,
|
||||||
"summarization": NllbMoeForConditionalGeneration,
|
"summarization": NllbMoeForConditionalGeneration,
|
||||||
"text2text-generation": NllbMoeForConditionalGeneration,
|
"text2text-generation": NllbMoeForConditionalGeneration,
|
||||||
|
"translation": NllbMoeForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -242,8 +242,9 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
"conversational": PegasusForConditionalGeneration,
|
"conversational": PegasusForConditionalGeneration,
|
||||||
"feature-extraction": PegasusModel,
|
"feature-extraction": PegasusModel,
|
||||||
"summarization": PegasusForConditionalGeneration,
|
"summarization": PegasusForConditionalGeneration,
|
||||||
"text2text-generation": PegasusForConditionalGeneration,
|
|
||||||
"text-generation": PegasusForCausalLM,
|
"text-generation": PegasusForCausalLM,
|
||||||
|
"text2text-generation": PegasusForConditionalGeneration,
|
||||||
|
"translation": PegasusForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
"feature-extraction": TFPegasusModel,
|
"feature-extraction": TFPegasusModel,
|
||||||
"summarization": TFPegasusForConditionalGeneration,
|
"summarization": TFPegasusForConditionalGeneration,
|
||||||
"text2text-generation": TFPegasusForConditionalGeneration,
|
"text2text-generation": TFPegasusForConditionalGeneration,
|
||||||
|
"translation": TFPegasusForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
"feature-extraction": PegasusXModel,
|
"feature-extraction": PegasusXModel,
|
||||||
"summarization": PegasusXForConditionalGeneration,
|
"summarization": PegasusXForConditionalGeneration,
|
||||||
"text2text-generation": PegasusXForConditionalGeneration,
|
"text2text-generation": PegasusXForConditionalGeneration,
|
||||||
|
"translation": PegasusXForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -224,9 +224,10 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
"conversational": PLBartForConditionalGeneration,
|
"conversational": PLBartForConditionalGeneration,
|
||||||
"feature-extraction": PLBartModel,
|
"feature-extraction": PLBartModel,
|
||||||
"summarization": PLBartForConditionalGeneration,
|
"summarization": PLBartForConditionalGeneration,
|
||||||
"text2text-generation": PLBartForConditionalGeneration,
|
|
||||||
"text-classification": PLBartForSequenceClassification,
|
"text-classification": PLBartForSequenceClassification,
|
||||||
"text-generation": PLBartForCausalLM,
|
"text-generation": PLBartForCausalLM,
|
||||||
|
"text2text-generation": PLBartForConditionalGeneration,
|
||||||
|
"translation": PLBartForConditionalGeneration,
|
||||||
"zero-shot": PLBartForSequenceClassification,
|
"zero-shot": PLBartForSequenceClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
|||||||
@@ -894,8 +894,9 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
"conversational": ProphetNetForConditionalGeneration,
|
"conversational": ProphetNetForConditionalGeneration,
|
||||||
"feature-extraction": ProphetNetModel,
|
"feature-extraction": ProphetNetModel,
|
||||||
"summarization": ProphetNetForConditionalGeneration,
|
"summarization": ProphetNetForConditionalGeneration,
|
||||||
"text2text-generation": ProphetNetForConditionalGeneration,
|
|
||||||
"text-generation": ProphetNetForCausalLM,
|
"text-generation": ProphetNetForCausalLM,
|
||||||
|
"text2text-generation": ProphetNetForConditionalGeneration,
|
||||||
|
"translation": ProphetNetForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -558,6 +558,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||||||
"feature-extraction": SwitchTransformersModel,
|
"feature-extraction": SwitchTransformersModel,
|
||||||
"summarization": SwitchTransformersForConditionalGeneration,
|
"summarization": SwitchTransformersForConditionalGeneration,
|
||||||
"text2text-generation": SwitchTransformersForConditionalGeneration,
|
"text2text-generation": SwitchTransformersForConditionalGeneration,
|
||||||
|
"translation": SwitchTransformersForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -528,6 +528,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
"feature-extraction": T5Model,
|
"feature-extraction": T5Model,
|
||||||
"summarization": T5ForConditionalGeneration,
|
"summarization": T5ForConditionalGeneration,
|
||||||
"text2text-generation": T5ForConditionalGeneration,
|
"text2text-generation": T5ForConditionalGeneration,
|
||||||
|
"translation": T5ForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -250,6 +250,7 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
"feature-extraction": TFT5Model,
|
"feature-extraction": TFT5Model,
|
||||||
"summarization": TFT5ForConditionalGeneration,
|
"summarization": TFT5ForConditionalGeneration,
|
||||||
"text2text-generation": TFT5ForConditionalGeneration,
|
"text2text-generation": TFT5ForConditionalGeneration,
|
||||||
|
"translation": TFT5ForConditionalGeneration,
|
||||||
}
|
}
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -277,7 +277,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"automatic-speech-recognition": WhisperForConditionalGeneration, "feature-extraction": WhisperModel}
|
{
|
||||||
|
"audio-classification": WhisperForAudioClassification,
|
||||||
|
"automatic-speech-recognition": WhisperForConditionalGeneration,
|
||||||
|
"feature-extraction": WhisperModel,
|
||||||
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user