This reverts commit 5aa361f3e5.
This commit is contained in:
committed by
GitHub
parent
66e9608bae
commit
8f07f5c44b
@@ -113,10 +113,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
self.eval_max_length = self.hparams.eval_max_gen_length
|
self.eval_max_length = self.hparams.eval_max_gen_length
|
||||||
else:
|
else:
|
||||||
self.eval_max_length = self.model.config.max_length
|
self.eval_max_length = self.model.config.max_length
|
||||||
if self.hparams.eval_min_gen_length is not None:
|
|
||||||
self.eval_min_length = self.hparams.eval_min_gen_length
|
|
||||||
else:
|
|
||||||
self.eval_min_length = self.model.config.min_length
|
|
||||||
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
||||||
|
|
||||||
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
|
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
|
||||||
@@ -223,7 +219,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
decoder_start_token_id=self.decoder_start_token_id,
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
num_beams=self.eval_beams,
|
num_beams=self.eval_beams,
|
||||||
max_length=self.eval_max_length,
|
max_length=self.eval_max_length,
|
||||||
min_length=self.eval_min_length,
|
|
||||||
)
|
)
|
||||||
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
||||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||||
@@ -351,7 +346,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
|
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
|
||||||
)
|
)
|
||||||
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
|
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
|
||||||
parser.add_argument("--eval_min_gen_length", type=int, default=None, help="never generate shorter than n tokens")
|
|
||||||
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stopping_patience",
|
"--early_stopping_patience",
|
||||||
|
|||||||
Reference in New Issue
Block a user