[tests] remove pt_tf equivalence tests (#36253)

This commit is contained in:
Joao Gante
2025-02-19 11:55:11 +00:00
committed by GitHub
parent 1a81d774b1
commit 0863eef248
60 changed files with 56 additions and 2438 deletions

View File

@@ -736,55 +736,6 @@ NEW_BERT_CONSTANT = "value"
self.assertIsNone(bert_model_patterns.feature_extractor_class)
self.assertIsNone(bert_model_patterns.processor_class)
def test_retrieve_info_for_model_pt_tf_with_bert(self):
bert_info = retrieve_info_for_model("bert", frameworks=["pt", "tf"])
bert_classes = [
"BertForTokenClassification",
"BertForQuestionAnswering",
"BertForNextSentencePrediction",
"BertForSequenceClassification",
"BertForMaskedLM",
"BertForMultipleChoice",
"BertModel",
"BertForPreTraining",
"BertLMHeadModel",
]
expected_model_classes = {"pt": set(bert_classes), "tf": {f"TF{m}" for m in bert_classes}}
self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf"})
model_classes = {k: set(v) for k, v in bert_info["model_classes"].items()}
self.assertEqual(model_classes, expected_model_classes)
all_bert_files = bert_info["model_files"]
model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["model_files"]}
bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_flax_bert.py"}
self.assertEqual(model_files, bert_model_files)
test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
bert_test_files = {
"tests/models/bert/test_tokenization_bert.py",
"tests/models/bert/test_modeling_bert.py",
"tests/models/bert/test_modeling_tf_bert.py",
}
self.assertEqual(test_files, bert_test_files)
doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
self.assertEqual(doc_file, "docs/source/en/model_doc/bert.md")
self.assertEqual(all_bert_files["module_name"], "bert")
bert_model_patterns = bert_info["model_patterns"]
self.assertEqual(bert_model_patterns.model_name, "BERT")
self.assertEqual(bert_model_patterns.checkpoint, "google-bert/bert-base-uncased")
self.assertEqual(bert_model_patterns.model_type, "bert")
self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
self.assertEqual(bert_model_patterns.model_upper_cased, "BERT")
self.assertEqual(bert_model_patterns.config_class, "BertConfig")
self.assertEqual(bert_model_patterns.tokenizer_class, "BertTokenizer")
self.assertIsNone(bert_model_patterns.feature_extractor_class)
self.assertIsNone(bert_model_patterns.processor_class)
def test_retrieve_info_for_model_with_vit(self):
vit_info = retrieve_info_for_model("vit")
vit_classes = ["ViTForImageClassification", "ViTModel"]