[Generate Test] fix greedy generate test (#8293)
* fix greedy generate test * delet ipdb
This commit is contained in:
committed by
GitHub
parent
734afa37f6
commit
cb966e640b
@@ -140,10 +140,6 @@ class GenerationTesterMixin:
|
|||||||
# check `generate()` and `greedy_search()` are equal
|
# check `generate()` and `greedy_search()` are equal
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
|
||||||
model, input_ids, attention_mask
|
|
||||||
)
|
|
||||||
kwargs["encoder_outputs"] = encoder_outputs
|
|
||||||
max_length = 4
|
max_length = 4
|
||||||
|
|
||||||
output_ids_generate = model.generate(
|
output_ids_generate = model.generate(
|
||||||
@@ -154,6 +150,13 @@ class GenerationTesterMixin:
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||||
|
model, input_ids, attention_mask
|
||||||
|
)
|
||||||
|
kwargs["encoder_outputs"] = encoder_outputs
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output_ids_greedy = model.greedy_search(
|
output_ids_greedy = model.greedy_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user