[RAG] Fix rag from pretrained question encoder generator behavior (#11962)
* fix_torch_device_generate_test * remove @ * fix rag from pretrained loading * add test * uplaod * finish
This commit is contained in:
committed by
GitHub
parent
6db3a87de2
commit
43f46aa7fd
@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase):
|
||||
"facebook/bart-large-cnn",
|
||||
retriever=rag_retriever,
|
||||
config=rag_config,
|
||||
question_encoder_max_length=200,
|
||||
generator_max_length=200,
|
||||
).to(torch_device)
|
||||
# check that the from pretrained methods work
|
||||
rag_token.save_pretrained(tmp_dirname)
|
||||
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
|
||||
rag_token.to(torch_device)
|
||||
|
||||
self.assertTrue(rag_token.question_encoder.config.max_length == 200)
|
||||
self.assertTrue(rag_token.generator.config.max_length == 200)
|
||||
|
||||
with torch.no_grad():
|
||||
output = rag_token(
|
||||
input_ids,
|
||||
|
||||
Reference in New Issue
Block a user