[s2s] command line args for faster val steps (#6833)
This commit is contained in:
@@ -262,7 +262,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
|
|||||||
|
|
||||||
mode = "translation"
|
mode = "translation"
|
||||||
metric_names = ["bleu"]
|
metric_names = ["bleu"]
|
||||||
val_metric = "bleu"
|
default_val_metric = "bleu"
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
mode = "summarization"
|
mode = "summarization"
|
||||||
loss_names = ["loss"]
|
loss_names = ["loss"]
|
||||||
metric_names = ROUGE_KEYS
|
metric_names = ROUGE_KEYS
|
||||||
val_metric = "rouge2"
|
default_val_metric = "rouge2"
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||||
@@ -110,6 +110,9 @@ class SummarizationModule(BaseTransformer):
|
|||||||
self.dataset_class = (
|
self.dataset_class = (
|
||||||
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||||
)
|
)
|
||||||
|
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
||||||
|
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
|
||||||
|
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
@@ -301,6 +304,8 @@ class SummarizationModule(BaseTransformer):
|
|||||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||||
|
parser.add_argument("--eval_beams", type=int, default=None, required=False)
|
||||||
|
parser.add_argument("--val_metric", type=str, default=None, required=False)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stopping_patience",
|
"--early_stopping_patience",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -315,7 +320,7 @@ class TranslationModule(SummarizationModule):
|
|||||||
mode = "translation"
|
mode = "translation"
|
||||||
loss_names = ["loss"]
|
loss_names = ["loss"]
|
||||||
metric_names = ["bleu"]
|
metric_names = ["bleu"]
|
||||||
val_metric = "bleu"
|
default_val_metric = "bleu"
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ logger = logging.getLogger()
|
|||||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
CHEAP_ARGS = {
|
CHEAP_ARGS = {
|
||||||
"label_smoothing": 0.2,
|
"label_smoothing": 0.2,
|
||||||
|
"eval_beams": 1,
|
||||||
|
"val_metric": None,
|
||||||
"adafactor": True,
|
"adafactor": True,
|
||||||
"early_stopping_patience": 2,
|
"early_stopping_patience": 2,
|
||||||
"logger_name": "default",
|
"logger_name": "default",
|
||||||
|
|||||||
Reference in New Issue
Block a user