Change REALM checkpoint to new ones (#15439)
* Change REALM checkpoint to new ones * Last checkpoint missing
This commit is contained in:
@@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
|
||||
config.return_dict = True
|
||||
|
||||
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
|
||||
# RealmKnowledgeAugEncoder training
|
||||
model = RealmKnowledgeAugEncoder(config)
|
||||
@@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_embedder_from_pretrained(self):
|
||||
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
|
||||
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_encoder_from_pretrained(self):
|
||||
model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
|
||||
model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_open_qa_from_pretrained(self):
|
||||
model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_reader_from_pretrained(self):
|
||||
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader")
|
||||
model = RealmReader.from_pretrained("google/realm-orqa-nq-reader")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_scorer_from_pretrained(self):
|
||||
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
|
||||
model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_embedder(self):
|
||||
retriever_projected_size = 128
|
||||
|
||||
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
|
||||
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
@@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
||||
vocab_size = 30522
|
||||
|
||||
model = RealmKnowledgeAugEncoder.from_pretrained(
|
||||
"qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
|
||||
"google/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
|
||||
)
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
|
||||
relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)
|
||||
@@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
config = RealmConfig()
|
||||
|
||||
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
|
||||
model = RealmForOpenQA.from_pretrained(
|
||||
"qqaatw/realm-orqa-nq-openqa",
|
||||
"google/realm-orqa-nq-openqa",
|
||||
retriever=retriever,
|
||||
config=config,
|
||||
)
|
||||
@@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_reader(self):
|
||||
config = RealmConfig(reader_beam_size=2, max_span_width=3)
|
||||
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config)
|
||||
model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config)
|
||||
|
||||
concat_input_ids = torch.arange(10).view((2, 5))
|
||||
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
|
||||
@@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_scorer(self):
|
||||
num_candidates = 2
|
||||
|
||||
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
|
||||
model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
|
||||
|
||||
@@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase):
|
||||
mock_hf_hub_download.return_value = os.path.join(
|
||||
os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
|
||||
)
|
||||
retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa")
|
||||
retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa")
|
||||
|
||||
self.assertEqual(retriever.block_records[0], b"This is the first record")
|
||||
|
||||
Reference in New Issue
Block a user