Fixed parameter name for logits_processor (#9790)

This commit is contained in:
Michael Glass
2021-01-26 10:44:02 -05:00
committed by GitHub
parent 0d0efd3a0e
commit c37dcff764

View File

@@ -1486,7 +1486,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
)
return self.greedy_search(
input_ids,
pre_processor=pre_processor,
logits_processor=pre_processor,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@@ -1509,7 +1509,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
return self.beam_search(
input_ids,
beam_scorer,
pre_processor=pre_processor,
logits_processor=pre_processor,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,