[s2s] distill t5-large -> t5-small (#8376)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a5b682329c
commit
81ebd70671
@@ -380,7 +380,7 @@ cp xsum/test* all_pl
|
|||||||
then use `all_pl` as DATA in the command above.
|
then use `all_pl` as DATA in the command above.
|
||||||
|
|
||||||
#### Direct Knowledge Distillation (KD)
|
#### Direct Knowledge Distillation (KD)
|
||||||
+ In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
|
+ In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `SummarizationDistiller`.
|
||||||
+ This method was used for `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced.
|
+ This method was used for `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced.
|
||||||
+ You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you.
|
+ You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you.
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
|||||||
from lightning_base import generic_train # noqa
|
from lightning_base import generic_train # noqa
|
||||||
|
|
||||||
|
|
||||||
class BartSummarizationDistiller(SummarizationModule):
|
class SummarizationDistiller(SummarizationModule):
|
||||||
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
"""Supports T5, Bart, Pegasus and other models that inherit from Bart."""
|
||||||
|
|
||||||
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
|
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||||
|
|
||||||
@@ -40,26 +40,38 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student
|
hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student
|
||||||
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
|
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
|
||||||
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
|
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
|
||||||
|
if hparams.student is not None:
|
||||||
|
student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student)
|
||||||
|
use_task_specific_params(student, hparams.task)
|
||||||
|
e_layer_ids, d_layer_ids = None, None
|
||||||
|
else:
|
||||||
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
|
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
|
||||||
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
|
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
if hparams.length_penalty != -1:
|
if hparams.length_penalty != -1:
|
||||||
student.config.length_penalty = hparams.length_penalty
|
student.config.length_penalty = hparams.length_penalty
|
||||||
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
|
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
|
||||||
super().__init__(hparams, model=student, config=student.config)
|
super().__init__(hparams, model=student, config=student.config)
|
||||||
model_type = student.config.model_type
|
assert (
|
||||||
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
|
student.config.model_type == teacher.config.model_type
|
||||||
|
), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"
|
||||||
|
|
||||||
if model_type == "t5":
|
if student.config.model_type == "t5":
|
||||||
|
student_encoder_layers = len(student.get_encoder().block)
|
||||||
|
student_decoder_layers = len(student.get_decoder().block)
|
||||||
teacher_encoder_layers = len(teacher.get_encoder().block)
|
teacher_encoder_layers = len(teacher.get_encoder().block)
|
||||||
teacher_decoder_layers = len(teacher.get_decoder().block)
|
teacher_decoder_layers = len(teacher.get_decoder().block)
|
||||||
else:
|
else:
|
||||||
|
student_encoder_layers = student.config.encoder_layers
|
||||||
|
student_decoder_layers = student.config.decoder_layers
|
||||||
teacher_encoder_layers = teacher.config.encoder_layers
|
teacher_encoder_layers = teacher.config.encoder_layers
|
||||||
teacher_decoder_layers = teacher.config.decoder_layers
|
teacher_decoder_layers = teacher.config.decoder_layers
|
||||||
|
|
||||||
self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
|
self.different_base_models = not (hparams.student is None or hparams.teacher == hparams.student)
|
||||||
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
|
self.do_calc_hidden_loss = (not self.different_base_models) and hparams.alpha_hid > 0
|
||||||
|
self.different_encoder = self.different_base_models or (student_encoder_layers != teacher_encoder_layers)
|
||||||
|
# self.different_encoder determines whether we need to run the teacher encoder
|
||||||
self.teacher = teacher
|
self.teacher = teacher
|
||||||
freeze_params(self.teacher)
|
freeze_params(self.teacher)
|
||||||
|
|
||||||
@@ -68,13 +80,28 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
del self.teacher.model.encoder
|
del self.teacher.model.encoder
|
||||||
except AttributeError: # T5
|
except AttributeError: # T5
|
||||||
del self.teacher.encoder
|
del self.teacher.encoder
|
||||||
# Intermediate supervision: Decide which layers to supervise
|
|
||||||
|
if e_layer_ids is None:
|
||||||
|
e_layer_ids = list(range(student_encoder_layers))
|
||||||
|
if d_layer_ids is None:
|
||||||
|
d_layer_ids = list(range(student_decoder_layers))
|
||||||
|
|
||||||
|
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
|
||||||
|
|
||||||
|
if self.do_calc_hidden_loss: # Intermediate supervision: Decide which layers to supervise
|
||||||
if hparams.supervise_forward:
|
if hparams.supervise_forward:
|
||||||
self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
|
self.e_matches = get_layers_to_supervise(
|
||||||
self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
|
n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
|
||||||
|
)
|
||||||
|
self.d_matches = get_layers_to_supervise(
|
||||||
|
n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
|
||||||
|
)
|
||||||
else: # student layer should emulate hidden states of the teacher layer it was copied from
|
else: # student layer should emulate hidden states of the teacher layer it was copied from
|
||||||
self.e_matches = self.e_layer_ids
|
self.e_matches = self.e_layer_ids
|
||||||
self.d_matches = self.d_layer_ids
|
self.d_matches = self.d_layer_ids
|
||||||
|
else:
|
||||||
|
self.e_matches = None
|
||||||
|
self.d_matches = None
|
||||||
|
|
||||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
self.temperature = 2.0
|
self.temperature = 2.0
|
||||||
@@ -84,22 +111,8 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
|
||||||
"""Supervise MSE(teacher.encoder_outputs, student.encoder_outputs)."""
|
|
||||||
# raise NotImplementedError()
|
|
||||||
if mask is not None:
|
|
||||||
# mask has False at padding_idx
|
|
||||||
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
|
|
||||||
s_logits_slct = torch.masked_select(student_outputs, sel_mask)
|
|
||||||
t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
|
|
||||||
else:
|
|
||||||
t_logits_slct = teacher_outputs
|
|
||||||
s_logits_slct = student_outputs
|
|
||||||
return F.mse_loss(s_logits_slct, t_logits_slct)
|
|
||||||
|
|
||||||
def calc_ce_loss(self, mask, s_logits, t_logits):
|
def calc_ce_loss(self, mask, s_logits, t_logits):
|
||||||
"""Copy pasted from distillbert (transformers/examples/distillation/)"""
|
"""Copy pasted from distillbert (transformers/examples/distillation/)"""
|
||||||
|
|
||||||
# mask has False at padding_idx
|
# mask has False at padding_idx
|
||||||
sel_mask = mask[:, :, None].expand_as(s_logits)
|
sel_mask = mask[:, :, None].expand_as(s_logits)
|
||||||
vocab_size = s_logits.size(-1)
|
vocab_size = s_logits.size(-1)
|
||||||
@@ -123,8 +136,8 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
add_distill_args(parser)
|
add_distill_args(parser)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def _step(self, batch):
|
def _step(self, batch: dict) -> tuple:
|
||||||
# assert is_frozen(self.teacher) copied_decoder_layers
|
"""Compute the loss for a batch"""
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
|
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
|
||||||
if isinstance(self.model, T5ForConditionalGeneration):
|
if isinstance(self.model, T5ForConditionalGeneration):
|
||||||
@@ -133,14 +146,16 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
|
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
|
||||||
|
|
||||||
# noinspection PyCallingNonCallable
|
# noinspection PyCallingNonCallable
|
||||||
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
student_outputs = self(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=src_mask,
|
attention_mask=src_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
output_hidden_states=True,
|
output_hidden_states=self.do_calc_hidden_loss,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
|
return_dict=True,
|
||||||
)
|
)
|
||||||
|
lm_logits = student_outputs.logits
|
||||||
|
|
||||||
# Same cross entropy vs. label smoothing logic as finetune.py
|
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||||
@@ -149,7 +164,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||||
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
|
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
|
||||||
else:
|
else:
|
||||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
lprobs = F.log_softmax(lm_logits, dim=-1)
|
||||||
student_lm_loss, _ = label_smoothed_nll_loss(
|
student_lm_loss, _ = label_smoothed_nll_loss(
|
||||||
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||||
)
|
)
|
||||||
@@ -157,37 +172,44 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
def zero_tensor():
|
def zero_tensor():
|
||||||
return torch.tensor(0.0).type_as(student_lm_loss)
|
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||||
|
|
||||||
|
teacher_enc_outputs = student_outputs.encoder_last_hidden_state # use this unless self.different_base_models
|
||||||
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
||||||
if self.different_encoder: # compute encoder hidden state loss
|
if self.different_encoder: # compute encoder hidden state loss
|
||||||
with torch.no_grad():
|
all_teacher_encoder_outputs = self.teacher.get_encoder()(
|
||||||
teacher_enc_hid = self.teacher.get_encoder()(
|
input_ids,
|
||||||
input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
|
attention_mask=src_mask,
|
||||||
).hidden_states
|
output_hidden_states=self.do_calc_hidden_loss,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
if self.different_base_models:
|
||||||
|
teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
|
||||||
|
elif self.do_calc_hidden_loss:
|
||||||
hid_loss_enc = self.calc_hidden_loss(
|
hid_loss_enc = self.calc_hidden_loss(
|
||||||
src_mask,
|
src_mask,
|
||||||
enc_hidden_state,
|
student_outputs.encoder_hidden_states,
|
||||||
teacher_enc_hid,
|
all_teacher_encoder_outputs.hidden_states,
|
||||||
self.e_matches,
|
self.e_matches,
|
||||||
normalize_hidden=self.hparams.normalize_hidden,
|
normalize_hidden=self.hparams.normalize_hidden,
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
teacher_outputs = self.teacher(
|
||||||
outputs = self.teacher(
|
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=src_mask,
|
attention_mask=src_mask,
|
||||||
encoder_outputs=(enc_outputs,),
|
encoder_outputs=(teacher_enc_outputs,),
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
lm_labels=labels,
|
output_hidden_states=self.do_calc_hidden_loss,
|
||||||
output_hidden_states=True,
|
use_cache=False, # since we are not passing labels, never let this default to True
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
|
|
||||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
|
||||||
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
|
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
|
||||||
hid_loss_dec = self.calc_hidden_loss(
|
hid_loss_dec = self.calc_hidden_loss(
|
||||||
dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
|
dec_mask,
|
||||||
|
student_outputs.decoder_hidden_states,
|
||||||
|
teacher_outputs.decoder_hidden_states,
|
||||||
|
self.d_matches,
|
||||||
|
normalize_hidden=self.hparams.normalize_hidden,
|
||||||
)
|
)
|
||||||
|
|
||||||
blended_loss = (
|
blended_loss = (
|
||||||
@@ -207,6 +229,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||||
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
|
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
|
||||||
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
||||||
|
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
|
||||||
if normalize_hidden:
|
if normalize_hidden:
|
||||||
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
||||||
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||||
@@ -216,10 +239,16 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
|
|
||||||
|
|
||||||
def add_distill_args(parser):
|
def add_distill_args(parser):
|
||||||
|
# NOTE: if --student argument was specified and the teacher and student base models
|
||||||
|
# are different, the models still have to have the same tokenizer, specified by
|
||||||
|
# --tokenizer_name. So, for example, you can distill from t5_large to t5_small but not
|
||||||
|
# from bart to t5. This s because if the tokenizers are different, the output space
|
||||||
|
# for the two models is also different and their logits are not comparable.
|
||||||
parser.add_argument("--teacher", type=str)
|
parser.add_argument("--teacher", type=str)
|
||||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||||
|
parser.add_argument("--student", type=str, required=False)
|
||||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||||
@@ -228,8 +257,8 @@ def add_distill_args(parser):
|
|||||||
parser.add_argument("--normalize_hidden", action="store_true", default=False)
|
parser.add_argument("--normalize_hidden", action="store_true", default=False)
|
||||||
|
|
||||||
|
|
||||||
class BartTranslationDistiller(BartSummarizationDistiller):
|
class TranslationDistiller(SummarizationDistiller):
|
||||||
"""Supports Mbart, Marian, other models that inherit from Bart."""
|
"""Supports T5, mBART, Marian, other models that inherit from Bart."""
|
||||||
|
|
||||||
mode = "translation"
|
mode = "translation"
|
||||||
metric_names = ["bleu"]
|
metric_names = ["bleu"]
|
||||||
@@ -258,7 +287,7 @@ def create_module(args):
|
|||||||
if args.no_teacher:
|
if args.no_teacher:
|
||||||
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
|
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
|
||||||
else: # DISTILL WITH TEACHER
|
else: # DISTILL WITH TEACHER
|
||||||
module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
|
module_cls = TranslationDistiller if "translation" in args.task else SummarizationDistiller
|
||||||
args.setup_cls: str = module_cls.__name__
|
args.setup_cls: str = module_cls.__name__
|
||||||
print(f"using module {args.setup_cls}")
|
print(f"using module {args.setup_cls}")
|
||||||
model = module_cls(args)
|
model = module_cls(args)
|
||||||
@@ -276,7 +305,7 @@ def distill_main(args):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
distill_main(args)
|
distill_main(args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import pytorch_lightning as pl
|
|||||||
import timeout_decorator
|
import timeout_decorator
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from distillation import BartSummarizationDistiller, distill_main
|
from distillation import SummarizationDistiller, distill_main
|
||||||
from finetune import SummarizationModule, main
|
from finetune import SummarizationModule, main
|
||||||
from transformers import MarianMTModel
|
from transformers import MarianMTModel
|
||||||
from transformers.file_utils import cached_path
|
from transformers.file_utils import cached_path
|
||||||
@@ -170,7 +170,7 @@ class TestDistilMarianNoTeacher(TestCasePlus):
|
|||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# assert args.gpus == gpus THIS BREAKS for multi_gpu
|
# assert args.gpus == gpus THIS BREAKS for multi_gpu
|
||||||
|
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ CHEAP_ARGS = {
|
|||||||
"freeze_encoder": False,
|
"freeze_encoder": False,
|
||||||
"auto_scale_batch_size": False,
|
"auto_scale_batch_size": False,
|
||||||
"overwrite_output_dir": False,
|
"overwrite_output_dir": False,
|
||||||
|
"student": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -107,6 +108,7 @@ def _dump_articles(path: Path, articles: list):
|
|||||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||||
|
T5_TINIER = "sshleifer/t5-tinier-random"
|
||||||
BART_TINY = "sshleifer/bart-tiny-random"
|
BART_TINY = "sshleifer/bart-tiny-random"
|
||||||
MBART_TINY = "sshleifer/tiny-mbart"
|
MBART_TINY = "sshleifer/tiny-mbart"
|
||||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||||
@@ -239,6 +241,16 @@ class TestSummarizationDistiller(TestCasePlus):
|
|||||||
)
|
)
|
||||||
self._test_distiller_cli(updates)
|
self._test_distiller_cli(updates)
|
||||||
|
|
||||||
|
@require_torch_non_multi_gpu_but_fix_me
|
||||||
|
def test_distill_different_base_models(self):
|
||||||
|
updates = dict(
|
||||||
|
teacher=T5_TINY,
|
||||||
|
student=T5_TINIER,
|
||||||
|
model_name_or_path=T5_TINIER,
|
||||||
|
tokenizer_name=T5_TINIER,
|
||||||
|
)
|
||||||
|
self._test_distiller_cli(updates)
|
||||||
|
|
||||||
def _test_distiller_cli(self, updates, check_contents=True):
|
def _test_distiller_cli(self, updates, check_contents=True):
|
||||||
default_updates = dict(
|
default_updates = dict(
|
||||||
label_smoothing=0.0,
|
label_smoothing=0.0,
|
||||||
|
|||||||
Reference in New Issue
Block a user