Fixed parameter name for logits_processor (#9790)
This commit is contained in:
@@ -1486,7 +1486,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
)
|
||||
return self.greedy_search(
|
||||
input_ids,
|
||||
pre_processor=pre_processor,
|
||||
logits_processor=pre_processor,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
@@ -1509,7 +1509,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
return self.beam_search(
|
||||
input_ids,
|
||||
beam_scorer,
|
||||
pre_processor=pre_processor,
|
||||
logits_processor=pre_processor,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
|
||||
Reference in New Issue
Block a user