[tests] remove pt_tf equivalence tests (#36253)
This commit is contained in:
@@ -736,55 +736,6 @@ NEW_BERT_CONSTANT = "value"
|
||||
self.assertIsNone(bert_model_patterns.feature_extractor_class)
|
||||
self.assertIsNone(bert_model_patterns.processor_class)
|
||||
|
||||
def test_retrieve_info_for_model_pt_tf_with_bert(self):
|
||||
bert_info = retrieve_info_for_model("bert", frameworks=["pt", "tf"])
|
||||
bert_classes = [
|
||||
"BertForTokenClassification",
|
||||
"BertForQuestionAnswering",
|
||||
"BertForNextSentencePrediction",
|
||||
"BertForSequenceClassification",
|
||||
"BertForMaskedLM",
|
||||
"BertForMultipleChoice",
|
||||
"BertModel",
|
||||
"BertForPreTraining",
|
||||
"BertLMHeadModel",
|
||||
]
|
||||
expected_model_classes = {"pt": set(bert_classes), "tf": {f"TF{m}" for m in bert_classes}}
|
||||
|
||||
self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf"})
|
||||
model_classes = {k: set(v) for k, v in bert_info["model_classes"].items()}
|
||||
self.assertEqual(model_classes, expected_model_classes)
|
||||
|
||||
all_bert_files = bert_info["model_files"]
|
||||
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["model_files"]}
|
||||
bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_flax_bert.py"}
|
||||
self.assertEqual(model_files, bert_model_files)
|
||||
|
||||
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
|
||||
bert_test_files = {
|
||||
"tests/models/bert/test_tokenization_bert.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"tests/models/bert/test_modeling_tf_bert.py",
|
||||
}
|
||||
self.assertEqual(test_files, bert_test_files)
|
||||
|
||||
doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
|
||||
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.md")
|
||||
|
||||
self.assertEqual(all_bert_files["module_name"], "bert")
|
||||
|
||||
bert_model_patterns = bert_info["model_patterns"]
|
||||
self.assertEqual(bert_model_patterns.model_name, "BERT")
|
||||
self.assertEqual(bert_model_patterns.checkpoint, "google-bert/bert-base-uncased")
|
||||
self.assertEqual(bert_model_patterns.model_type, "bert")
|
||||
self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
|
||||
self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
|
||||
self.assertEqual(bert_model_patterns.model_upper_cased, "BERT")
|
||||
self.assertEqual(bert_model_patterns.config_class, "BertConfig")
|
||||
self.assertEqual(bert_model_patterns.tokenizer_class, "BertTokenizer")
|
||||
self.assertIsNone(bert_model_patterns.feature_extractor_class)
|
||||
self.assertIsNone(bert_model_patterns.processor_class)
|
||||
|
||||
def test_retrieve_info_for_model_with_vit(self):
|
||||
vit_info = retrieve_info_for_model("vit")
|
||||
vit_classes = ["ViTForImageClassification", "ViTModel"]
|
||||
|
||||
Reference in New Issue
Block a user