Fixed parameter name for logits_processor (#9790)
This commit is contained in:
@@ -1486,7 +1486,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
)
|
)
|
||||||
return self.greedy_search(
|
return self.greedy_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
pre_processor=pre_processor,
|
logits_processor=pre_processor,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
@@ -1509,7 +1509,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
return self.beam_search(
|
return self.beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
pre_processor=pre_processor,
|
logits_processor=pre_processor,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user