Update all references to canonical models (#29001)

* Script & Manual edition

* Update
This commit is contained in:
Lysandre Debut
2024-02-16 08:16:58 +01:00
committed by GitHub
parent 1e402b957d
commit f497f564bb
561 changed files with 2682 additions and 2687 deletions

View File

@@ -228,7 +228,7 @@ class SomeClass:
)
def test_replace_model_patterns(self):
-bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
+bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
"""
@@ -312,14 +312,14 @@ GPT_NEW_NEW_CONSTANT = "value"
# in others.
self.assertEqual(replacements, "")
-roberta_model_patterns = ModelPatterns("RoBERTa", "roberta-base", model_camel_cased="Roberta")
+roberta_model_patterns = ModelPatterns("RoBERTa", "FacebookAI/roberta-base", model_camel_cased="Roberta")
new_roberta_model_patterns = ModelPatterns(
"RoBERTa-New", "huggingface/roberta-new-base", model_camel_cased="RobertaNew"
)
roberta_test = '''# Copied from transformers.models.bert.BertModel with Bert->Roberta
class RobertaModel(RobertaPreTrainedModel):
""" The base RoBERTa model. """
-checkpoint = roberta-base
+checkpoint = FacebookAI/roberta-base
base_model_prefix = "roberta"
'''
roberta_expected = '''# Copied from transformers.models.bert.BertModel with Bert->RobertaNew
@@ -346,7 +346,7 @@ class RobertaNewModel(RobertaNewPreTrainedModel):
get_module_from_file("/models/gpt2/modeling_gpt2.py")
def test_duplicate_module(self):
-bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
+bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
"""
@@ -395,7 +395,7 @@ NEW_BERT_CONSTANT = "value"
self.check_result(dest_file_name, bert_expected)
def test_duplicate_module_with_copied_from(self):
-bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
+bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
bert_test = '''# Copied from transformers.models.xxx.XxxModel with Xxx->Bert
class TFBertPreTrainedModel(PreTrainedModel):
@@ -656,7 +656,7 @@ NEW_BERT_CONSTANT = "value"
self.assertEqual(test_files, wav2vec2_test_files)
def test_find_base_model_checkpoint(self):
-self.assertEqual(find_base_model_checkpoint("bert"), "bert-base-uncased")
+self.assertEqual(find_base_model_checkpoint("bert"), "google-bert/bert-base-uncased")
self.assertEqual(find_base_model_checkpoint("gpt2"), "gpt2")
def test_retrieve_model_classes(self):
@@ -719,7 +719,7 @@ NEW_BERT_CONSTANT = "value"
bert_model_patterns = bert_info["model_patterns"]
self.assertEqual(bert_model_patterns.model_name, "BERT")
-self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased")
+self.assertEqual(bert_model_patterns.checkpoint, "google-bert/bert-base-uncased")
self.assertEqual(bert_model_patterns.model_type, "bert")
self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
@@ -768,7 +768,7 @@ NEW_BERT_CONSTANT = "value"
bert_model_patterns = bert_info["model_patterns"]
self.assertEqual(bert_model_patterns.model_name, "BERT")
-self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased")
+self.assertEqual(bert_model_patterns.checkpoint, "google-bert/bert-base-uncased")
self.assertEqual(bert_model_patterns.model_type, "bert")
self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")

View File

@@ -105,7 +105,7 @@ class GetFromCacheTests(unittest.TestCase):
def test_get_file_from_repo_distant(self):
# `get_file_from_repo` returns None if the file does not exist
-self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt"))
+self.assertIsNone(get_file_from_repo("google-bert/bert-base-cased", "ahah.txt"))
# The function raises if the repository does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
@@ -113,9 +113,9 @@ class GetFromCacheTests(unittest.TestCase):
# The function raises if the revision does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
-get_file_from_repo("bert-base-cased", CONFIG_NAME, revision="ahaha")
+get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME, revision="ahaha")
-resolved_file = get_file_from_repo("bert-base-cased", CONFIG_NAME)
+resolved_file = get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME)
# The name is the cached name which is not very easy to test, so instead we load the content.
config = json.loads(open(resolved_file, "r").read())
self.assertEqual(config["hidden_size"], 768)