[examples/seq2seq]: add --label_smoothing option (#5919)

This commit is contained in:
Sam Shleifer
2020-07-21 16:51:39 -04:00
committed by GitHub
parent 95d1962b9c
commit 5b193b39b0
7 changed files with 132 additions and 46 deletions

View File

@@ -33,9 +33,10 @@ try:
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
label_smoothed_nll_loss,
)
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
except ImportError:
from utils import (
Seq2SeqDataset,
@@ -52,8 +53,9 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
label_smoothed_nll_loss,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
logger = logging.getLogger(__name__)
@@ -128,12 +130,22 @@ class SummarizationModule(BaseTransformer):
def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
y_ids = y[:, :-1].contiguous()
lm_labels = y[:, 1:].clone()
lm_labels[y[:, 1:] == pad_token_id] = -100
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
loss = outputs[0]
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
lm_labels = target_ids[:, 1:].clone() # why clone?
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
lm_logits = outputs[0]
assert lm_logits.shape[-1] == self.model.config.vocab_size
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
)
return (loss,)
def training_step(self, batch, batch_idx) -> Dict:
@@ -290,8 +302,16 @@ class SummarizationModule(BaseTransformer):
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
)
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
parser.add_argument(
"--early_stopping_patience",
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
)
return parser
@@ -335,17 +355,24 @@ def main(args, model=None) -> SummarizationModule:
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=dataset)
project = os.environ.get("WANDB_PROJECT", dataset)
logger = WandbLogger(name=model.output_dir.name, project=project)
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
if args.early_stopping_patience >= 0:
es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
else:
es_callback = False
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=logger,
# TODO: early stopping callback seems messed up
)