[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)
* fix rag * fix slow test * fix past in bart
This commit is contained in:
committed by
GitHub
parent
6587cf9f84
commit
fa1ddced9e
@@ -535,7 +535,6 @@ class RagDPRBartTest(RagTestMixin, unittest.TestCase):
|
||||
n_docs=self.n_docs,
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
max_combined_length=self.max_combined_length,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -565,7 +564,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
||||
n_docs=self.n_docs,
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
max_combined_length=self.max_combined_length,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -758,8 +756,8 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
generator_tokenizer=rag_decoder_tokenizer,
|
||||
)
|
||||
|
||||
rag_token = self.sequence_model
|
||||
rag_token.set_retriever(rag_retriever)
|
||||
rag_sequence = self.sequence_model
|
||||
rag_sequence.set_retriever(rag_retriever)
|
||||
|
||||
input_ids = rag_question_encoder_tokenizer(
|
||||
"who sings does he love me with reba", return_tensors="pt"
|
||||
@@ -767,9 +765,9 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_ids = rag_token.generate(
|
||||
output_ids = rag_sequence.generate(
|
||||
input_ids,
|
||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
||||
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
@@ -810,7 +808,7 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
retriever = RagRetriever.from_pretrained(
|
||||
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
|
||||
)
|
||||
rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
||||
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
@@ -844,9 +842,9 @@ class RagModelIntegrationTests(unittest.TestCase):
|
||||
" walls of the abdomen",
|
||||
" spodumene",
|
||||
" obama",
|
||||
" grainger's compound",
|
||||
" new orleans",
|
||||
" japan",
|
||||
" old trafford stadium",
|
||||
" old trafford",
|
||||
]
|
||||
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user