[RAG] Fix rag from pretrained question encoder generator behavior (#11962)
* fix_torch_device_generate_test * remove @ * fix rag from pretrained loading * add test * upload * finish
This commit is contained in:
committed by
GitHub
parent
6db3a87de2
commit
43f46aa7fd
@@ -245,7 +245,6 @@ class RagPreTrainedModel(PreTrainedModel):
|
|||||||
question_encoder_pretrained_model_name_or_path: str = None,
|
question_encoder_pretrained_model_name_or_path: str = None,
|
||||||
generator_pretrained_model_name_or_path: str = None,
|
generator_pretrained_model_name_or_path: str = None,
|
||||||
retriever: RagRetriever = None,
|
retriever: RagRetriever = None,
|
||||||
*model_args,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> PreTrainedModel:
|
) -> PreTrainedModel:
|
||||||
r"""
|
r"""
|
||||||
@@ -310,7 +309,7 @@ class RagPreTrainedModel(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
kwargs_question_encoder = {
|
kwargs_question_encoder = {
|
||||||
argument[len("question_question_encoder_") :]: value
|
argument[len("question_encoder_") :]: value
|
||||||
for argument, value in kwargs.items()
|
for argument, value in kwargs.items()
|
||||||
if argument.startswith("question_encoder_")
|
if argument.startswith("question_encoder_")
|
||||||
}
|
}
|
||||||
@@ -340,11 +339,15 @@ class RagPreTrainedModel(PreTrainedModel):
|
|||||||
if "config" not in kwargs_question_encoder:
|
if "config" not in kwargs_question_encoder:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
|
||||||
question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
|
question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
|
||||||
|
question_encoder_pretrained_model_name_or_path,
|
||||||
|
**kwargs_question_encoder,
|
||||||
|
return_unused_kwargs=True,
|
||||||
|
)
|
||||||
kwargs_question_encoder["config"] = question_encoder_config
|
kwargs_question_encoder["config"] = question_encoder_config
|
||||||
|
|
||||||
question_encoder = AutoModel.from_pretrained(
|
question_encoder = AutoModel.from_pretrained(
|
||||||
question_encoder_pretrained_model_name_or_path, *model_args, **kwargs_question_encoder
|
question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = kwargs_generator.pop("model", None)
|
generator = kwargs_generator.pop("model", None)
|
||||||
@@ -357,7 +360,10 @@ class RagPreTrainedModel(PreTrainedModel):
|
|||||||
if "config" not in kwargs_generator:
|
if "config" not in kwargs_generator:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
|
||||||
generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
|
generator_config, kwargs_generator = AutoConfig.from_pretrained(
|
||||||
|
generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
kwargs_generator["config"] = generator_config
|
kwargs_generator["config"] = generator_config
|
||||||
|
|
||||||
generator = AutoModelForSeq2SeqLM.from_pretrained(
|
generator = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
|
|||||||
@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase):
|
|||||||
"facebook/bart-large-cnn",
|
"facebook/bart-large-cnn",
|
||||||
retriever=rag_retriever,
|
retriever=rag_retriever,
|
||||||
config=rag_config,
|
config=rag_config,
|
||||||
|
question_encoder_max_length=200,
|
||||||
|
generator_max_length=200,
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
# check that the from pretrained methods work
|
# check that the from pretrained methods work
|
||||||
rag_token.save_pretrained(tmp_dirname)
|
rag_token.save_pretrained(tmp_dirname)
|
||||||
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
|
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
|
||||||
rag_token.to(torch_device)
|
rag_token.to(torch_device)
|
||||||
|
|
||||||
|
self.assertTrue(rag_token.question_encoder.config.max_length == 200)
|
||||||
|
self.assertTrue(rag_token.generator.config.max_length == 200)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = rag_token(
|
output = rag_token(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user