[s2s] use --eval_beams command line arg (#6926)
This commit is contained in:
@@ -202,6 +202,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
attention_mask=batch["attention_mask"],
|
attention_mask=batch["attention_mask"],
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
decoder_start_token_id=self.decoder_start_token_id,
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
num_beams=self.eval_beams,
|
||||||
)
|
)
|
||||||
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
||||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user