Update all references to canonical models (#29001)
* Script & Manual edition
* Update
This commit is contained in:
@@ -228,7 +228,7 @@ class SomeClass:
|
||||
)
|
||||
|
||||
def test_replace_model_patterns(self):
|
||||
bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
|
||||
bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
|
||||
new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
|
||||
bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
@@ -312,14 +312,14 @@ GPT_NEW_NEW_CONSTANT = "value"
|
||||
# in others.
|
||||
self.assertEqual(replacements, "")
|
||||
|
||||
roberta_model_patterns = ModelPatterns("RoBERTa", "roberta-base", model_camel_cased="Roberta")
|
||||
roberta_model_patterns = ModelPatterns("RoBERTa", "FacebookAI/roberta-base", model_camel_cased="Roberta")
|
||||
new_roberta_model_patterns = ModelPatterns(
|
||||
"RoBERTa-New", "huggingface/roberta-new-base", model_camel_cased="RobertaNew"
|
||||
)
|
||||
roberta_test = '''# Copied from transformers.models.bert.BertModel with Bert->Roberta
|
||||
class RobertaModel(RobertaPreTrainedModel):
|
||||
""" The base RoBERTa model. """
|
||||
checkpoint = roberta-base
|
||||
checkpoint = FacebookAI/roberta-base
|
||||
base_model_prefix = "roberta"
|
||||
'''
|
||||
roberta_expected = '''# Copied from transformers.models.bert.BertModel with Bert->RobertaNew
|
||||
@@ -346,7 +346,7 @@ class RobertaNewModel(RobertaNewPreTrainedModel):
|
||||
get_module_from_file("/models/gpt2/modeling_gpt2.py")
|
||||
|
||||
def test_duplicate_module(self):
|
||||
bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
|
||||
bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
|
||||
new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
|
||||
bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
@@ -395,7 +395,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
self.check_result(dest_file_name, bert_expected)
|
||||
|
||||
def test_duplicate_module_with_copied_from(self):
|
||||
bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
|
||||
bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
|
||||
new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
|
||||
bert_test = '''# Copied from transformers.models.xxx.XxxModel with Xxx->Bert
|
||||
class TFBertPreTrainedModel(PreTrainedModel):
|
||||
@@ -656,7 +656,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
self.assertEqual(test_files, wav2vec2_test_files)
|
||||
|
||||
def test_find_base_model_checkpoint(self):
|
||||
self.assertEqual(find_base_model_checkpoint("bert"), "bert-base-uncased")
|
||||
self.assertEqual(find_base_model_checkpoint("bert"), "google-bert/bert-base-uncased")
|
||||
self.assertEqual(find_base_model_checkpoint("gpt2"), "gpt2")
|
||||
|
||||
def test_retrieve_model_classes(self):
|
||||
@@ -719,7 +719,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
bert_model_patterns = bert_info["model_patterns"]
|
||||
self.assertEqual(bert_model_patterns.model_name, "BERT")
|
||||
self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased")
|
||||
self.assertEqual(bert_model_patterns.checkpoint, "google-bert/bert-base-uncased")
|
||||
self.assertEqual(bert_model_patterns.model_type, "bert")
|
||||
self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
|
||||
self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
|
||||
@@ -768,7 +768,7 @@ NEW_BERT_CONSTANT = "value"
|
||||
|
||||
bert_model_patterns = bert_info["model_patterns"]
|
||||
self.assertEqual(bert_model_patterns.model_name, "BERT")
|
||||
self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased")
|
||||
self.assertEqual(bert_model_patterns.checkpoint, "google-bert/bert-base-uncased")
|
||||
self.assertEqual(bert_model_patterns.model_type, "bert")
|
||||
self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
|
||||
self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
|
||||
|
||||
@@ -105,7 +105,7 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
|
||||
def test_get_file_from_repo_distant(self):
|
||||
# `get_file_from_repo` returns None if the file does not exist
|
||||
self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt"))
|
||||
self.assertIsNone(get_file_from_repo("google-bert/bert-base-cased", "ahah.txt"))
|
||||
|
||||
# The function raises if the repository does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
|
||||
@@ -113,9 +113,9 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
|
||||
# The function raises if the revision does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
|
||||
get_file_from_repo("bert-base-cased", CONFIG_NAME, revision="ahaha")
|
||||
get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME, revision="ahaha")
|
||||
|
||||
resolved_file = get_file_from_repo("bert-base-cased", CONFIG_NAME)
|
||||
resolved_file = get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME)
|
||||
# The name is the cached name which is not very easy to test, so instead we load the content.
|
||||
config = json.loads(open(resolved_file, "r").read())
|
||||
self.assertEqual(config["hidden_size"], 768)
|
||||
|
||||
Reference in New Issue
Block a user