[RAG] Fix rag from pretrained question encoder generator behavior (#11962)

* fix_torch_device_generate_test

* remove @

* fix rag from pretrained loading

* add test

* uplaod

* finish
This commit is contained in:
Patrick von Platen
2021-06-02 09:17:14 +01:00
committed by GitHub
parent 6db3a87de2
commit 43f46aa7fd
2 changed files with 16 additions and 5 deletions

View File

@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase):
"facebook/bart-large-cnn",
retriever=rag_retriever,
config=rag_config,
question_encoder_max_length=200,
generator_max_length=200,
).to(torch_device)
# check that the from pretrained methods work
rag_token.save_pretrained(tmp_dirname)
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
rag_token.to(torch_device)
self.assertTrue(rag_token.question_encoder.config.max_length == 200)
self.assertTrue(rag_token.generator.config.max_length == 200)
with torch.no_grad():
output = rag_token(
input_ids,