[s2s] add create student script (#7290)
Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -369,7 +369,7 @@ runtime: 6H on NVIDIA RTX 24GB GPU
|
|||||||
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
|
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
|
||||||
because you will have the same hyperparameters logged in every run.
|
because you will have the same hyperparameters logged in every run.
|
||||||
|
|
||||||
#### With a teacher
|
#### With a teacher (Intermediate Supervision)
|
||||||
*Note* only BART variants are supported
|
*Note* only BART variants are supported
|
||||||
|
|
||||||
In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
|
In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
|
||||||
@@ -378,7 +378,7 @@ This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
|
|||||||
The command that produced `sshleifer/distilbart-xsum-12-6` is:
|
The command that produced `sshleifer/distilbart-xsum-12-6` is:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./train_distilbart_xsum.sh
|
./train_distilbart_xsum.sh --logger_name wandb --gpus 1
|
||||||
```
|
```
|
||||||
|
|
||||||
runtime: 13H on V-100 16GB GPU.
|
runtime: 13H on V-100 16GB GPU.
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -15,18 +14,10 @@ from torch.nn import functional as F
|
|||||||
|
|
||||||
from finetune import SummarizationModule, TranslationModule
|
from finetune import SummarizationModule, TranslationModule
|
||||||
from finetune import main as ft_main
|
from finetune import main as ft_main
|
||||||
from initialization_utils import copy_layers, init_student
|
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
|
||||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from utils import (
|
from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, pickle_load, use_task_specific_params
|
||||||
any_requires_grad,
|
|
||||||
assert_all_frozen,
|
|
||||||
calculate_bleu,
|
|
||||||
freeze_params,
|
|
||||||
label_smoothed_nll_loss,
|
|
||||||
pickle_load,
|
|
||||||
use_task_specific_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# need the parent dir module
|
# need the parent dir module
|
||||||
@@ -41,87 +32,50 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
assert Path(hparams.data_dir).exists()
|
assert Path(hparams.data_dir).exists()
|
||||||
student, student_cfg, teacher = self.pre_init(hparams)
|
self.output_dir = Path(hparams.output_dir)
|
||||||
|
self.output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
super().__init__(hparams, model=student, config=student_cfg)
|
save_dir = self.output_dir.joinpath("student")
|
||||||
|
|
||||||
|
hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student
|
||||||
|
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
|
||||||
|
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
|
||||||
|
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
|
||||||
|
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
|
||||||
|
)
|
||||||
|
if hparams.length_penalty != -1:
|
||||||
|
student.config.length_penalty = hparams.length_penalty
|
||||||
|
super().__init__(hparams, model=student, config=student.config)
|
||||||
|
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
|
||||||
|
self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
|
||||||
|
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
|
||||||
self.teacher = teacher
|
self.teacher = teacher
|
||||||
use_task_specific_params(self.teacher, "summarization")
|
|
||||||
freeze_params(self.teacher)
|
freeze_params(self.teacher)
|
||||||
self.sanity_check_gradients()
|
|
||||||
|
if not self.different_encoder: # To save RAM, delete teacher encoder and freeze student encoder.
|
||||||
|
try:
|
||||||
|
del self.teacher.model.encoder
|
||||||
|
except AttributeError: # T5
|
||||||
|
del self.teacher.encoder
|
||||||
|
# Intermediate supervision: Decide which layers to supervise
|
||||||
|
if hparams.supervise_forward:
|
||||||
|
self.d_matches = get_layers_to_supervise(
|
||||||
|
n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.d_matches = self.d_layer_ids
|
||||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
self.temperature = 2.0
|
self.temperature = 2.0
|
||||||
self.alpha_mlm = hparams.alpha_mlm
|
self.alpha_mlm = hparams.alpha_mlm
|
||||||
self.alpha_ce = hparams.alpha_ce
|
self.alpha_ce = hparams.alpha_ce
|
||||||
self.alpha_hid = hparams.alpha_hid
|
self.alpha_hid = hparams.alpha_hid
|
||||||
# self.alpha_cos = hparams.alpha_cos
|
self.alpha_encoder_loss = hparams.alpha_encoder_loss
|
||||||
self.alpha_encoder_loss = self.hparams.alpha_encoder_loss
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def sanity_check_gradients(self):
|
|
||||||
assert_all_frozen(self.teacher)
|
|
||||||
assert_all_frozen(self.model.model.decoder.embed_tokens)
|
|
||||||
assert_all_frozen(self.model.model.encoder.embed_tokens)
|
|
||||||
if self.different_encoder:
|
|
||||||
assert any_requires_grad(self.model.model.encoder)
|
|
||||||
else:
|
|
||||||
freeze_params(self.model.model.encoder)
|
|
||||||
del self.teacher.model.encoder
|
|
||||||
|
|
||||||
def pre_init(self, hparams):
|
|
||||||
self.output_dir = Path(hparams.output_dir)
|
|
||||||
self.output_dir.mkdir(exist_ok=True)
|
|
||||||
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
|
|
||||||
student_updates = {
|
|
||||||
"decoder_layers": hparams.student_decoder_layers,
|
|
||||||
"encoder_layers": hparams.student_encoder_layers,
|
|
||||||
}
|
|
||||||
if hparams.length_penalty != -1:
|
|
||||||
student_updates["length_penalty"] = hparams.length_penalty
|
|
||||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
|
||||||
hparams.e_layer_to_copy = e_layers_to_copy
|
|
||||||
|
|
||||||
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
|
||||||
|
|
||||||
if hparams.supervise_forward:
|
|
||||||
hparams.d_matches = get_layers_to_supervise(
|
|
||||||
student_updates["decoder_layers"], teacher.config.decoder_layers
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hparams.d_matches = d_layers_to_copy
|
|
||||||
hparams.d_layer_to_copy = d_layers_to_copy
|
|
||||||
|
|
||||||
kw = teacher.config.to_diff_dict()
|
|
||||||
kw.update(student_updates)
|
|
||||||
# Copy weights
|
|
||||||
student_cfg = teacher.config_class(**kw)
|
|
||||||
student = type(teacher)(student_cfg)
|
|
||||||
student, _ = init_student(student, teacher)
|
|
||||||
save_dir = self.output_dir.joinpath("student")
|
|
||||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
|
||||||
student.save_pretrained(save_dir)
|
|
||||||
hparams.model_name_or_path = str(save_dir)
|
|
||||||
return student, student_cfg, teacher
|
|
||||||
|
|
||||||
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
|
||||||
if teacher.config.model_type == "t5":
|
|
||||||
return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
|
||||||
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers
|
|
||||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
|
|
||||||
if self.different_decoder:
|
|
||||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
|
|
||||||
if self.different_encoder:
|
|
||||||
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
|
|
||||||
|
|
||||||
def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
|
||||||
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers
|
|
||||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers
|
|
||||||
if self.different_decoder:
|
|
||||||
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
|
|
||||||
if self.different_encoder:
|
|
||||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
|
||||||
|
|
||||||
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
||||||
|
"""Supervise MSE(teacher.encoder_outputs, student.encoder_outputs)."""
|
||||||
|
# raise NotImplementedError()
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# mask has False at padding_idx
|
# mask has False at padding_idx
|
||||||
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
|
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
|
||||||
@@ -133,20 +87,15 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
return F.mse_loss(s_logits_slct, t_logits_slct)
|
return F.mse_loss(s_logits_slct, t_logits_slct)
|
||||||
|
|
||||||
def calc_ce_loss(self, mask, s_logits, t_logits):
|
def calc_ce_loss(self, mask, s_logits, t_logits):
|
||||||
if mask is not None:
|
"""Copy pasted from distilbert (transformers/examples/distillation/)"""
|
||||||
# mask has False at padding_idx
|
|
||||||
sel_mask = mask[:, :, None].expand_as(s_logits)
|
# mask has False at padding_idx
|
||||||
s_logits_slct = torch.masked_select(
|
sel_mask = mask[:, :, None].expand_as(s_logits)
|
||||||
s_logits, sel_mask
|
vocab_size = s_logits.size(-1)
|
||||||
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
s_logits_slct = torch.masked_select(s_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||||
t_logits_slct = torch.masked_select(
|
t_logits_slct = torch.masked_select(t_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||||
t_logits, sel_mask
|
s_logits_slct = s_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||||
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
t_logits_slct = t_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||||
else:
|
|
||||||
t_logits_slct = t_logits
|
|
||||||
s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask
|
|
||||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
|
||||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
|
||||||
assert t_logits_slct.size() == s_logits_slct.size()
|
assert t_logits_slct.size() == s_logits_slct.size()
|
||||||
loss_ce = (
|
loss_ce = (
|
||||||
self.ce_loss_fct(
|
self.ce_loss_fct(
|
||||||
@@ -155,7 +104,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
)
|
)
|
||||||
* (self.temperature) ** 2
|
* (self.temperature) ** 2
|
||||||
)
|
)
|
||||||
return loss_ce, s_logits_slct, t_logits_slct
|
return loss_ce
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
@@ -164,10 +113,14 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
def _step(self, batch):
|
def _step(self, batch):
|
||||||
# assert is_frozen(self.teacher)
|
# assert is_frozen(self.teacher) copied_decoder_layers
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
|
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
|
||||||
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
|
if isinstance(self.model, T5ForConditionalGeneration):
|
||||||
|
decoder_input_ids = self.model._shift_right(labels)
|
||||||
|
else:
|
||||||
|
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
|
||||||
|
|
||||||
# noinspection PyCallingNonCallable
|
# noinspection PyCallingNonCallable
|
||||||
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -183,11 +136,11 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
if self.hparams.label_smoothing == 0:
|
if self.hparams.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||||
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
|
||||||
else:
|
else:
|
||||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||||
student_lm_loss, _ = label_smoothed_nll_loss(
|
student_lm_loss, _ = label_smoothed_nll_loss(
|
||||||
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def zero_tensor():
|
def zero_tensor():
|
||||||
@@ -196,15 +149,14 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||||
if self.different_encoder:
|
if self.different_encoder:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
|
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
|
||||||
input_ids, attention_mask=src_mask, output_hidden_states=True
|
input_ids, attention_mask=src_mask, output_hidden_states=True
|
||||||
)
|
)
|
||||||
|
# DEPRECATE THIS
|
||||||
if self.hparams.alpha_encoder_loss > 0:
|
if self.hparams.alpha_encoder_loss > 0:
|
||||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
|
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
|
||||||
|
|
||||||
hid_loss_enc = self.calc_hidden_loss(
|
hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
|
||||||
src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
|
||||||
)
|
|
||||||
|
|
||||||
teacher_enc_outputs = (enc_outputs,)
|
teacher_enc_outputs = (enc_outputs,)
|
||||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||||
@@ -215,13 +167,15 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
attention_mask=src_mask,
|
attention_mask=src_mask,
|
||||||
encoder_outputs=teacher_enc_outputs,
|
encoder_outputs=teacher_enc_outputs,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
lm_labels=tgt_ids,
|
lm_labels=labels,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
)
|
)
|
||||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
||||||
if self.alpha_hid > 0:
|
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
|
||||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
|
hid_loss_dec = self.calc_hidden_loss(
|
||||||
|
dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
|
||||||
|
)
|
||||||
|
|
||||||
blended_loss = (
|
blended_loss = (
|
||||||
self.alpha_ce * loss_ce
|
self.alpha_ce * loss_ce
|
||||||
@@ -231,7 +185,9 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
)
|
)
|
||||||
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
|
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||||
|
|
||||||
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
|
@staticmethod
|
||||||
|
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
|
||||||
|
"""MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
|
||||||
msg = "expected list or tuple for hidden_states, got tensor of shape: "
|
msg = "expected list or tuple for hidden_states, got tensor of shape: "
|
||||||
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
|
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
|
||||||
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
||||||
@@ -239,7 +195,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||||
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
|
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
|
||||||
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
||||||
if self.hparams.normalize_hidden:
|
if normalize_hidden:
|
||||||
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
||||||
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||||
mse = F.mse_loss(student_states, teacher_states, reduction="none")
|
mse = F.mse_loss(student_states, teacher_states, reduction="none")
|
||||||
@@ -287,130 +243,9 @@ class BartTranslationDistiller(BartSummarizationDistiller):
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class T5SummarizationDistiller(BartSummarizationDistiller):
|
|
||||||
def pre_init(self, hparams):
|
|
||||||
raise NotImplementedError("T5 Distillation does not work yet")
|
|
||||||
self.output_dir = Path(hparams.output_dir)
|
|
||||||
self.output_dir.mkdir(exist_ok=True)
|
|
||||||
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
|
||||||
n_layer = hparams.student_decoder_layers
|
|
||||||
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6.
|
|
||||||
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
|
|
||||||
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
|
|
||||||
student_updates = {"num_layers": n_layer}
|
|
||||||
hparams.d_layer_to_copy = d_layers_to_copy
|
|
||||||
hparams.e_layer_to_copy = e_layers_to_copy
|
|
||||||
kw = teacher.config.to_diff_dict()
|
|
||||||
|
|
||||||
kw.update(student_updates)
|
|
||||||
# Copy weights
|
|
||||||
student_cfg = T5Config(**kw)
|
|
||||||
student = T5ForConditionalGeneration(student_cfg)
|
|
||||||
student, _ = init_student(student, teacher)
|
|
||||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
|
||||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
|
||||||
task_specific_params = student.config.task_specific_params
|
|
||||||
if task_specific_params is not None:
|
|
||||||
student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode
|
|
||||||
save_dir = self.output_dir.joinpath("student")
|
|
||||||
save_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
student.save_pretrained(save_dir)
|
|
||||||
hparams.model_name_or_path = str(save_dir)
|
|
||||||
return student, student_cfg, teacher
|
|
||||||
|
|
||||||
def freeze_embeds(self):
|
|
||||||
freeze_params(self.model.shared)
|
|
||||||
for d in [self.model.encoder, self.model.decoder]:
|
|
||||||
freeze_params(d.embed_tokens)
|
|
||||||
|
|
||||||
def sanity_check_gradients(self):
|
|
||||||
"""T5"""
|
|
||||||
assert_all_frozen(self.teacher)
|
|
||||||
assert_all_frozen(self.model.decoder.embed_tokens)
|
|
||||||
assert_all_frozen(self.model.encoder.embed_tokens)
|
|
||||||
if self.different_encoder:
|
|
||||||
assert any_requires_grad(self.model.encoder)
|
|
||||||
else:
|
|
||||||
freeze_params(self.model.encoder)
|
|
||||||
del self.teacher.model.encoder
|
|
||||||
if self.different_decoder:
|
|
||||||
assert any_requires_grad(self.model.decoder)
|
|
||||||
else:
|
|
||||||
freeze_params(self.model.decoder) # TODO(SS): very suspicious
|
|
||||||
|
|
||||||
def _step(self, batch):
|
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
|
||||||
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
|
||||||
decoder_input_ids = y[:, :-1].contiguous()
|
|
||||||
labels = y[:, 1:].clone()
|
|
||||||
labels[y[:, 1:] == pad_token_id] = -100
|
|
||||||
# noinspection PyCallingNonCallable
|
|
||||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
|
||||||
|
|
||||||
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
|
||||||
source_ids,
|
|
||||||
attention_mask=source_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
labels=labels,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=False,
|
|
||||||
use_cache=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def zero_tensor():
|
|
||||||
return torch.tensor(0.0).type_as(sloss)
|
|
||||||
|
|
||||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
|
||||||
if self.different_encoder:
|
|
||||||
with torch.no_grad():
|
|
||||||
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
|
|
||||||
source_ids,
|
|
||||||
attention_mask=source_mask,
|
|
||||||
output_hidden_states=True,
|
|
||||||
use_cache=False,
|
|
||||||
)
|
|
||||||
if self.hparams.alpha_encoder_loss > 0:
|
|
||||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
|
|
||||||
|
|
||||||
hid_loss_enc = self.calc_hidden_loss(
|
|
||||||
source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
|
||||||
)
|
|
||||||
|
|
||||||
teacher_enc_outputs = (enc_outputs,)
|
|
||||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
|
||||||
source_ids,
|
|
||||||
attention_mask=source_mask,
|
|
||||||
encoder_outputs=teacher_enc_outputs,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
labels=labels,
|
|
||||||
output_hidden_states=True,
|
|
||||||
use_cache=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
|
||||||
if self.alpha_hid > 0:
|
|
||||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
|
|
||||||
|
|
||||||
blended_loss = (
|
|
||||||
self.alpha_ce * loss_ce
|
|
||||||
+ self.alpha_mlm * sloss
|
|
||||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
|
||||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
|
||||||
)
|
|
||||||
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
|
||||||
|
|
||||||
|
|
||||||
def create_module(args):
|
def create_module(args):
|
||||||
t5 = "t5" in args.model_name_or_path
|
|
||||||
if args.no_teacher:
|
if args.no_teacher:
|
||||||
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
|
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
|
||||||
elif t5: # DISTILL T5 WITH TEACHER FOR SUMMARIZATION
|
|
||||||
assert "translation" not in args.task, "t5 translation distillation not supported"
|
|
||||||
module_cls = T5SummarizationDistiller
|
|
||||||
else: # DISTILL WITH TEACHER
|
else: # DISTILL WITH TEACHER
|
||||||
module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
|
module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
|
||||||
args.setup_cls: str = module_cls.__name__
|
args.setup_cls: str = module_cls.__name__
|
||||||
@@ -443,56 +278,6 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
|||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
|
|
||||||
|
|
||||||
LAYERS_TO_COPY = {
|
|
||||||
# maps num layers in student -> which teacher layers to copy.
|
|
||||||
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
|
|
||||||
12: {
|
|
||||||
1: [0],
|
|
||||||
2: [0, 6],
|
|
||||||
3: [0, 6, 11],
|
|
||||||
4: [0, 4, 8, 11],
|
|
||||||
6: [0, 2, 4, 7, 9, 11],
|
|
||||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
|
||||||
12: list(range(12)),
|
|
||||||
},
|
|
||||||
16: { # maps num layers in student -> which teacher layers to copy
|
|
||||||
1: [0],
|
|
||||||
2: [0, 8],
|
|
||||||
3: [0, 8, 15],
|
|
||||||
4: [0, 5, 10, 15],
|
|
||||||
6: [0, 3, 6, 9, 12, 15],
|
|
||||||
8: [0, 2, 4, 6, 8, 10, 12, 15],
|
|
||||||
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
|
|
||||||
12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
|
|
||||||
16: list(range(16)),
|
|
||||||
},
|
|
||||||
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
|
|
||||||
}
|
|
||||||
LAYERS_TO_SUPERVISE = {
|
|
||||||
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
|
|
||||||
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
|
|
||||||
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
|
|
||||||
2: {1: [1], 2: [0, 1]},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_layers_to_supervise(n_student, n_teacher):
|
|
||||||
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
|
|
||||||
|
|
||||||
|
|
||||||
def get_layers_to_copy(n_student, n_teacher):
|
|
||||||
try:
|
|
||||||
val = LAYERS_TO_COPY[n_teacher][n_student]
|
|
||||||
assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student
|
|
||||||
return val
|
|
||||||
except KeyError:
|
|
||||||
if n_student != n_teacher:
|
|
||||||
warnings.warn(
|
|
||||||
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
|
||||||
)
|
|
||||||
return list(range(n_student))
|
|
||||||
|
|
||||||
|
|
||||||
def distill_main(args):
|
def distill_main(args):
|
||||||
Path(args.output_dir).mkdir(exist_ok=True)
|
Path(args.output_dir).mkdir(exist_ok=True)
|
||||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
def init_student(student, teacher):
|
|
||||||
teacher_state_dict = teacher.state_dict()
|
|
||||||
info = student.load_state_dict(teacher_state_dict, strict=False)
|
|
||||||
assert info.missing_keys == [], info.missing_keys
|
|
||||||
return student, info
|
|
||||||
|
|
||||||
|
|
||||||
def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]):
|
|
||||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy)
|
|
||||||
|
|
||||||
|
|
||||||
def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None:
|
|
||||||
layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
|
|
||||||
assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
|
|
||||||
student_layers.load_state_dict(layers_to_copy.state_dict())
|
|
||||||
167
examples/seq2seq/make_student.py
Normal file
167
examples/seq2seq/make_student.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None:
|
||||||
|
layers_to_copy = nn.ModuleList([l for i, l in enumerate(src_layers) if i in layers_to_copy])
|
||||||
|
assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
|
||||||
|
dest_layers.load_state_dict(layers_to_copy.state_dict())
|
||||||
|
|
||||||
|
|
||||||
|
LAYERS_TO_COPY = {
|
||||||
|
# maps num layers in teacher -> num_layers in student -> which teacher layers to copy.
|
||||||
|
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
|
||||||
|
12: {
|
||||||
|
1: [0], # This says that if the teacher has 12 layers and the student has 1, copy layer 0 of the teacher
|
||||||
|
2: [0, 6],
|
||||||
|
3: [0, 6, 11],
|
||||||
|
4: [0, 4, 8, 11],
|
||||||
|
6: [0, 2, 4, 7, 9, 11],
|
||||||
|
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||||
|
12: list(range(12)),
|
||||||
|
},
|
||||||
|
16: { # maps num layers in student -> which teacher layers to copy
|
||||||
|
1: [0],
|
||||||
|
2: [0, 8],
|
||||||
|
3: [0, 8, 15],
|
||||||
|
4: [0, 5, 10, 15],
|
||||||
|
6: [0, 3, 6, 9, 12, 15],
|
||||||
|
8: [0, 2, 4, 6, 8, 10, 12, 15],
|
||||||
|
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
|
||||||
|
12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
|
||||||
|
16: list(range(16)),
|
||||||
|
},
|
||||||
|
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
|
||||||
|
}
|
||||||
|
LAYERS_TO_SUPERVISE = {
|
||||||
|
# maps num layers in teacher -> num layers in student -> which teacher layers to supervise.
|
||||||
|
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
|
||||||
|
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
|
||||||
|
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def pick_layers_to_copy(n_student, n_teacher):
    """Choose which teacher layer indices a student with ``n_student`` layers should copy.

    Returns the hardcoded entry ``LAYERS_TO_COPY[n_teacher][n_student]`` when one exists.
    Otherwise falls back to the first ``n_student`` layers, emitting a warning unless
    the student and teacher have the same depth.
    """
    choices_for_teacher = LAYERS_TO_COPY.get(n_teacher, {})
    if n_student in choices_for_teacher:
        return choices_for_teacher[n_student]

    # No hardcoded choice for this (teacher, student) pair.
    if n_student != n_teacher:
        warnings.warn(
            f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
        )
    return list(range(n_student))
|
||||||
|
|
||||||
|
|
||||||
|
def get_layers_to_supervise(n_student, n_teacher) -> List[int]:
    """Return the teacher layer indices to supervise; used for the --supervise_forward kwarg.

    Args:
        n_student: number of layers in the student.
        n_teacher: number of layers in the teacher.

    Returns:
        Every teacher layer when the depths match, only the last teacher layer for a
        one-layer student, otherwise the ``LAYERS_TO_SUPERVISE[n_teacher][n_student]`` entry.

    Raises:
        ValueError: if the student has more layers than the teacher.
        KeyError: if the table has no entry for this (teacher, student) pair.
    """
    if n_student > n_teacher:
        raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}")
    if n_student == n_teacher:
        return list(range(n_teacher))
    if n_student == 1:
        return [n_teacher - 1]
    return LAYERS_TO_SUPERVISE[n_teacher][n_student]
|
||||||
|
|
||||||
|
|
||||||
|
def create_student_by_copying_alternating_layers(
    teacher: Union[str, PreTrainedModel],
    save_path: Union[str, Path] = "student",
    e: Union[int, None] = None,
    d: Union[int, None] = None,
    copy_first_teacher_layers=False,
    **extra_config_kwargs
) -> Tuple[PreTrainedModel, List[int], List[int]]:
    """Make a student by copying alternating layers from a teacher, save it to save_path.

    Args:
        teacher: str or PreTrainedModel. If str, this will call AutoModelForSeq2SeqLM.from_pretrained(teacher)
            before copying layers, and the teacher's tokenizer is also saved to save_path.
        save_path: where to save the student, defaults to student directory.
        e: how many Encoder layers should the student have, default is a full copy of the teacher
        d: how many Decoder layers should the student have, default is a full copy of the teacher
        copy_first_teacher_layers: [bool] don't copy alternating layers, just the first e/d.
        **extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used.

    Returns:
        student: new, smaller model. (Also saves it to save_path)
        e_layers_to_copy: list of which teacher encoder layers were used
        d_layers_to_copy: list of which teacher decoder layers were used

    Raises:
        AssertionError: if both e and d are None, if teacher is neither a str nor a PreTrainedModel,
            if the teacher config has no encoder_layers/decoder_layers (T5 path) and e != d, or if
            some student parameter has no counterpart in the teacher state dict.
    """
    # NOTE: the checks below stay as asserts on purpose: callers (including the tests shipped
    # with this script) expect AssertionError for invalid arguments.
    _msg = "encoder_layers and decoder_layers cannot be both None-- you would just have an identical teacher."
    assert (e is not None) or (d is not None), _msg
    if isinstance(teacher, str):
        AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path)  # purely for convenience
        teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval()
    else:
        assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
    init_kwargs = teacher.config.to_diff_dict()

    try:
        teacher_e, teacher_d = teacher.config.encoder_layers, teacher.config.decoder_layers
        if e is None:
            e = teacher_e
        if d is None:
            d = teacher_d
        init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
    except AttributeError:  # T5: a single num_layers setting is written for the student
        teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_hidden_layers
        assert e == d, "T5 Students must be symmetric"
        init_kwargs["num_layers"] = e

    # Kwargs to instantiate student = teacher kwargs with updated layer numbers + **extra_config_kwargs
    init_kwargs.update(extra_config_kwargs)

    # Copy weights
    student_cfg = teacher.config_class(**init_kwargs)
    student = AutoModelForSeq2SeqLM.from_config(student_cfg)
    # Start by copying the full teacher state dict this will copy the first N teacher layers to the student.
    info = student.load_state_dict(teacher.state_dict(), strict=False)
    assert info.missing_keys == [], info.missing_keys  # every student key should have a teacher keys.

    if copy_first_teacher_layers:
        # The state-dict load above already placed the first e/d teacher layers in the student.
        e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
    else:
        # Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer.
        e_layers_to_copy: List[int] = pick_layers_to_copy(e, teacher_e)
        d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)

        try:
            copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
            copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
        except AttributeError:  # For t5, student.model.encoder.layers is called student.encoder.block
            copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
            copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)

    logger.info(
        f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
    )
    # Save information about copying for easier reproducibility. This is recorded for both
    # copying strategies so every saved student documents where its layers came from.
    student.config.init_metadata = dict(
        teacher_type=teacher.config.model_type,
        copied_encoder_layers=e_layers_to_copy,
        copied_decoder_layers=d_layers_to_copy,
    )
    student.save_pretrained(save_path)

    return student, e_layers_to_copy, d_layers_to_copy
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: hand the function to python-fire, which drives it from the command line.
    fire.Fire(create_student_by_copying_alternating_layers)
|
||||||
26
examples/seq2seq/save_randomly_initialized_model.py
Executable file
26
examples/seq2seq/save_randomly_initialized_model.py
Executable file
@@ -0,0 +1,26 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def save_randomly_initialized_version(config_name: str, save_dir: str, **config_kwargs):
    """Build a model from a pretrained config (weights are not loaded) and save it with its tokenizer.

    Args:
        config_name: which config to use
        save_dir: where to save the resulting model and tokenizer
        config_kwargs: Passed to AutoConfig

    Returns:
        The newly created model.

    Usage::
        save_randomly_initialized_version("facebook/bart-large-cnn", "distilbart_random_cnn_6_3", encoder_layers=6, decoder_layers=3, num_beams=3)
    """
    config = AutoConfig.from_pretrained(config_name, **config_kwargs)
    # from_config instantiates the architecture only; no pretrained weights are fetched.
    random_model = AutoModelForSeq2SeqLM.from_config(config)
    random_model.save_pretrained(save_dir)
    # Ship the matching tokenizer alongside the model so save_dir is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(config_name)
    tokenizer.save_pretrained(save_dir)
    return random_model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: hand the function to python-fire, which drives it from the command line.
    fire.Fire(save_randomly_initialized_version)
|
||||||
Binary file not shown.
Binary file not shown.
41
examples/seq2seq/test_make_student.py
Normal file
41
examples/seq2seq/test_make_student.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from make_student import create_student_by_copying_alternating_layers
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
|
|
||||||
|
# Model identifiers passed as the `teacher` argument in the tests below.
TINY_BART = "sshleifer/bart-tiny-random"
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
class MakeStudentTester(unittest.TestCase):
    """Checks the layer counts of students built by create_student_by_copying_alternating_layers."""

    @cached_property
    def teacher_config(self):
        """Config of the BART teacher, used as the reference for expected layer counts."""
        return AutoConfig.from_pretrained(TINY_BART)

    def test_valid_t5(self):
        student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
        self.assertEqual(student.config.num_hidden_layers, 1)

    def test_invalid_t5(self):
        # T5 students must have the same e==d because there is only one config property
        with self.assertRaises(AssertionError):
            create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)

    def test_same_decoder_small_encoder(self):
        student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None)
        self.assertEqual(student.config.encoder_layers, 1)
        # d=None keeps the teacher's decoder depth, so compare against decoder_layers
        # (comparing with encoder_layers only passes for a symmetric teacher).
        self.assertEqual(student.config.decoder_layers, self.teacher_config.decoder_layers)

    def test_small_enc_small_dec(self):
        student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=1)
        self.assertEqual(student.config.encoder_layers, 1)
        self.assertEqual(student.config.decoder_layers, 1)

    def test_raises_assert(self):
        # Leaving both e and d unset would just reproduce the teacher, which is rejected.
        with self.assertRaises(AssertionError):
            create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=None, d=None)
|
||||||
@@ -1,21 +1,20 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
export BS=16
|
|
||||||
export GAS=2
|
|
||||||
python distillation.py \
|
python distillation.py \
|
||||||
|
--teacher facebook/bart-large-xsum --data_dir xsum \
|
||||||
|
--student_decoder_layers 6 --student_encoder_layers 12 \
|
||||||
|
--freeze_encoder --freeze_embeds \
|
||||||
--learning_rate=3e-4 \
|
--learning_rate=3e-4 \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
--fp16 \
|
--fp16 --fp16_opt_level=O1 \
|
||||||
--val_check_interval 0.1 --n_val 1000 \
|
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
|
||||||
--teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \
|
|
||||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||||
--student_decoder_layers 6 --student_encoder_layers 12 \
|
|
||||||
--freeze_encoder --freeze_embeds \
|
|
||||||
--model_name_or_path IGNORED \
|
--model_name_or_path IGNORED \
|
||||||
--alpha_hid=3. --length_penalty=0.5 \
|
--alpha_hid=3. \
|
||||||
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
|
--train_batch_size=16 --eval_batch_size=16 --gradient_accumulation_steps=2 \
|
||||||
--tokenizer_name facebook/bart-large \
|
--sortish_sampler \
|
||||||
|
--num_train_epochs=6 \
|
||||||
--warmup_steps 500 \
|
--warmup_steps 500 \
|
||||||
--output_dir distilbart_xsum_12_6 \
|
--output_dir distilbart_xsum_12_6 \
|
||||||
"$@"
|
"$@"
|
||||||
|
|||||||
Reference in New Issue
Block a user