Change REALM checkpoints to new ones (#15439)

* Change REALM checkpoints to new ones

* Update the last checkpoint, which was missing
This commit is contained in:
Sylvain Gugger
2022-01-31 12:50:20 -05:00
committed by GitHub
parent 7e56ba2864
commit 3385ca2582
6 changed files with 102 additions and 102 deletions

View File

@@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
config.return_dict = True
- tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
+ tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
# RealmKnowledgeAugEncoder training
model = RealmKnowledgeAugEncoder(config)
@@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
@slow
def test_embedder_from_pretrained(self):
- model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
+ model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
self.assertIsNotNone(model)
@slow
def test_encoder_from_pretrained(self):
- model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
+ model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder")
self.assertIsNotNone(model)
@slow
def test_open_qa_from_pretrained(self):
- model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa")
+ model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa")
self.assertIsNotNone(model)
@slow
def test_reader_from_pretrained(self):
- model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader")
+ model = RealmReader.from_pretrained("google/realm-orqa-nq-reader")
self.assertIsNotNone(model)
@slow
def test_scorer_from_pretrained(self):
- model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
+ model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer")
self.assertIsNotNone(model)
@@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
def test_inference_embedder(self):
retriever_projected_size = 128
- model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
+ model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
@@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
vocab_size = 30522
model = RealmKnowledgeAugEncoder.from_pretrained(
- "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
+ "google/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)
@@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase):
config = RealmConfig()
- tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
- retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
+ tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
+ retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained(
- "qqaatw/realm-orqa-nq-openqa",
+ "google/realm-orqa-nq-openqa",
retriever=retriever,
config=config,
)
@@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_reader(self):
config = RealmConfig(reader_beam_size=2, max_span_width=3)
- model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config)
+ model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config)
concat_input_ids = torch.arange(10).view((2, 5))
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
@@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
def test_inference_scorer(self):
num_candidates = 2
- model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
+ model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])

View File

@@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase):
mock_hf_hub_download.return_value = os.path.join(
os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
)
- retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa")
+ retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa")
self.assertEqual(retriever.block_records[0], b"This is the first record")