[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)

* fix rag

* fix slow test

* fix past in bart
This commit is contained in:
Patrick von Platen
2020-12-14 12:32:26 +01:00
committed by GitHub
parent 6587cf9f84
commit fa1ddced9e
3 changed files with 33 additions and 39 deletions

View File

@@ -535,7 +535,6 @@ class RagDPRBartTest(RagTestMixin, unittest.TestCase):
n_docs=self.n_docs,
retrieval_vector_size=self.retrieval_vector_size,
max_combined_length=self.max_combined_length,
use_cache=False,
)
return {
@@ -565,7 +564,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
n_docs=self.n_docs,
retrieval_vector_size=self.retrieval_vector_size,
max_combined_length=self.max_combined_length,
use_cache=False,
)
return {
@@ -758,8 +756,8 @@ class RagModelIntegrationTests(unittest.TestCase):
generator_tokenizer=rag_decoder_tokenizer,
)
rag_token = self.sequence_model
rag_token.set_retriever(rag_retriever)
rag_sequence = self.sequence_model
rag_sequence.set_retriever(rag_retriever)
input_ids = rag_question_encoder_tokenizer(
"who sings does he love me with reba", return_tensors="pt"
@@ -767,9 +765,9 @@ class RagModelIntegrationTests(unittest.TestCase):
input_ids = input_ids.to(torch_device)
output_ids = rag_token.generate(
output_ids = rag_sequence.generate(
input_ids,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=2,
num_return_sequences=2,
)
@@ -810,7 +808,7 @@ class RagModelIntegrationTests(unittest.TestCase):
retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)
rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
torch_device
)
@@ -844,9 +842,9 @@ class RagModelIntegrationTests(unittest.TestCase):
" walls of the abdomen",
" spodumene",
" obama",
" grainger's compound",
" new orleans",
" japan",
" old trafford stadium",
" old trafford",
]
self.assertListEqual(outputs, EXPECTED_OUTPUTS)