prepare_seq2seq_batch makes labels; decoder_input_ids are made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
This commit is contained in:
Sam Shleifer
2020-08-28 11:15:17 -04:00
committed by GitHub
parent cb276b41de
commit 9336086ab5
20 changed files with 429 additions and 290 deletions

View File

@@ -1,6 +1,7 @@
import argparse
import gc
import os
import warnings
from pathlib import Path
from typing import List
@@ -11,6 +12,7 @@ from torch.nn import functional as F
from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
try:
@@ -22,6 +24,7 @@ try:
assert_all_frozen,
calculate_bleu,
freeze_params,
label_smoothed_nll_loss,
pickle_load,
use_task_specific_params,
)
@@ -34,12 +37,15 @@ except ImportError:
assert_all_frozen,
calculate_bleu,
freeze_params,
label_smoothed_nll_loss,
pickle_load,
use_task_specific_params,
)
class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams):
@@ -160,22 +166,32 @@ class BartSummarizationDistiller(SummarizationModule):
def _step(self, batch):
# assert is_frozen(self.teacher)
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100
input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
# noinspection PyCallingNonCallable
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
input_ids,
attention_mask=src_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
output_attentions=False,
)
use_cache=False,
) # TODO(@sshleifer): return_dict=True cleanup
# Same cross entropy vs. label smoothing logic as finetune.py
assert lm_logits.shape[-1] == self.model.config.vocab_size
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
student_lm_loss, _ = label_smoothed_nll_loss(
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
)
def zero_tensor():
return torch.tensor(0.0).type_as(sloss)
return torch.tensor(0.0).type_as(student_lm_loss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
@@ -199,29 +215,26 @@ class BartSummarizationDistiller(SummarizationModule):
attention_mask=src_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
lm_labels=tgt_ids,
output_hidden_states=True,
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * sloss
+ self.alpha_mlm * student_lm_loss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
assert not isinstance(
hidden_states, torch.Tensor
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
assert not isinstance(
hidden_states_T, torch.Tensor
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
msg = "expected list or tuple for hidden_states, got tensor of shape: "
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
mask = attention_mask.to(hidden_states[0])
valid_count = mask.sum() * hidden_states[0].size(-1)
hidden_losses = [
@@ -233,7 +246,7 @@ class BartSummarizationDistiller(SummarizationModule):
def add_distill_args(parser):
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
@@ -245,8 +258,9 @@ def add_distill_args(parser):
class BartTranslationDistiller(BartSummarizationDistiller):
"""Supports Mbart, Marian, other models that inherit from Bart."""
mode = "translation"
loss_names = ["loss"]
metric_names = ["bleu"]
val_metric = "bleu"
@@ -368,7 +382,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
attention_mask=source_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
labels=labels,
output_hidden_states=True,
use_cache=False,
)
@@ -402,6 +416,7 @@ def create_module(args):
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
# TODO(SS): DELETE?
exp_dir = ckpt_path.parent
if dest_dir is None:
dest_dir = exp_dir
@@ -424,33 +439,40 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
trainer.test(model)
def get_layers_to_copy(n_to_get, tot):
    """Return the teacher layer indices to copy into a student.

    Args:
        n_to_get: number of layers in the student.
        tot: number of layers in the teacher.

    Returns:
        A new list of teacher layer indices. Hand-picked selections are used
        for 12- and 16-layer teachers; every other combination (including a
        student size missing from those tables) falls back to the first
        ``n_to_get`` teacher layers.
    """
    all_layers = list(range(tot))
    # Maps teacher size -> (student size -> which teacher layers to copy).
    hardcoded = {
        12: {  # Alternating for special cases
            1: [0],
            2: [0, 6],
            3: [0, 6, 11],
            4: [0, 4, 8, 11],
            6: [0, 2, 4, 7, 9, 11],
            9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
            12: all_layers,
        },
        16: {
            1: [0],
            2: [0, 8],
            3: [0, 8, 15],
            4: [0, 5, 10, 15],
            6: [0, 3, 6, 9, 12, 15],
            8: [0, 2, 4, 6, 8, 10, 12, 15],
            9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
            16: all_layers,
        },
    }
    try:
        # Copy so the caller always owns the list it receives.
        return list(hardcoded[tot][n_to_get])
    except KeyError:
        # Previously an untabulated student size for a 12/16-layer teacher
        # raised KeyError while other teacher sizes silently fell back; use
        # the same fallback everywhere.
        return all_layers[:n_to_get]  # TODO: better version on theseus-bart branch
LAYERS_TO_COPY = {
    # Outer key: number of layers in the teacher.
    # Inner mapping: num layers in student -> which teacher layers to copy.
    # 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
    12: {
        1: [0],
        2: [0, 6],
        3: [0, 6, 11],
        4: [0, 4, 8, 11],
        6: [0, 2, 4, 7, 9, 11],
        9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
        12: list(range(12)),
    },
    16: {
        1: [0],
        2: [0, 8],
        3: [0, 8, 15],
        4: [0, 5, 10, 15],
        6: [0, 3, 6, 9, 12, 15],
        8: [0, 2, 4, 6, 8, 10, 12, 15],
        9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
        16: list(range(16)),
    },
    6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}


def get_layers_to_copy(n_student, n_teacher):
    """Return the teacher layer indices to copy into the student.

    Args:
        n_student: number of layers in the student.
        n_teacher: number of layers in the teacher.

    Returns:
        A new list of teacher layer indices. When ``LAYERS_TO_COPY`` has an
        entry for this teacher/student pair that entry is used; otherwise the
        first ``n_student`` layers are returned.

    Warns:
        UserWarning: when no hardcoded entry exists and the student is not
        the same depth as the teacher, since taking the first ``n_student``
        layers is then only a default.
    """
    try:
        # Copy so callers cannot mutate the shared module-level table.
        return list(LAYERS_TO_COPY[n_teacher][n_student])
    except KeyError:
        # A same-depth student copies every teacher layer, which is exactly
        # what the fallback produces, so there is nothing to warn about.
        if n_student != n_teacher:
            warnings.warn(
                f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
            )
        return list(range(n_student))
def distill_main(args):