[examples/seq2seq]: add --label_smoothing option (#5919)
This commit is contained in:
@@ -33,9 +33,10 @@ try:
|
||||
calculate_bleu_score,
|
||||
Seq2SeqDataset,
|
||||
MBartDataset,
|
||||
label_smoothed_nll_loss,
|
||||
)
|
||||
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
except ImportError:
|
||||
from utils import (
|
||||
Seq2SeqDataset,
|
||||
@@ -52,8 +53,9 @@ except ImportError:
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
label_smoothed_nll_loss,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -128,12 +130,22 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
# Compute the loss for one batch and return it as a 1-tuple `(loss,)`.
# `batch` is read via the keys "input_ids", "attention_mask" and "decoder_input_ids".
# NOTE(review): this span is a rendered diff hunk (`@@ -128,12 +130,22 @@`), so the
# pre-change body and the post-change body of `_step` both appear below without
# +/- markers; the `|` / `||||` rows are separators from the diff rendering.
def _step(self, batch: dict) -> Tuple:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
# --- Pre-change body (the lines this hunk appears to remove) ---
# The loss was taken from the model itself: labels were passed into the forward
# call and outputs[0] was used as the loss.
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
y_ids = y[:, :-1].contiguous()
|
||||
lm_labels = y[:, 1:].clone()
|
||||
# Pad positions in the labels are overwritten with -100
# (presumably the ignore index of the model's internal loss -- TODO confirm).
lm_labels[y[:, 1:] == pad_token_id] = -100
|
||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
|
||||
loss = outputs[0]
|
||||
# --- Post-change body (the lines this hunk appears to add) ---
# The forward call no longer receives labels; the loss is computed here from
# outputs[0], which is treated as the LM logits.
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
# Decoder input is the target sequence without its last token, and the labels
# below are the same sequence without its first token, so the logits at
# position t are scored against target token t + 1.
decoder_input_ids = target_ids[:, :-1].contiguous() # targets minus the last token
|
||||
# NOTE(review): unlike the pre-change body, lm_labels is not modified in place
# anywhere in the visible code, so the clone looks unnecessary -- confirm before removing.
lm_labels = target_ids[:, 1:].clone() # targets minus the first token
|
||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
||||
|
||||
# No smoothing requested: plain cross entropy over the flattened logits,
# skipping positions whose label equals pad_token_id.
if self.hparams.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
lm_logits = outputs[0]
|
||||
# Sanity check: the last logits dimension must equal the configured vocabulary size.
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
|
||||
# Smoothing requested: log-softmax over the last dimension, then the
# label_smoothed_nll_loss helper imported from utils; its second return
# value (nll_loss) is not used afterwards.
else:
|
||||
lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
)
|
||||
# The loss is wrapped in a 1-tuple in both branches.
return (loss,)
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
@@ -290,8 +302,16 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument(
|
||||
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
|
||||
)
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -335,17 +355,24 @@ def main(args, model=None) -> SummarizationModule:
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
if args.early_stopping_patience >= 0:
|
||||
es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
else:
|
||||
es_callback = False
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
# TODO: early stopping callback seems messed up
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user