From 0d8f8848d5de1e6f4a785484f5dbe331d6a28e2a Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Wed, 28 Aug 2019 04:00:19 +0000 Subject: [PATCH] add `scripts/extract_for_distil.py` --- .../scripts/extract_for_distil.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 examples/distillation/scripts/extract_for_distil.py diff --git a/examples/distillation/scripts/extract_for_distil.py b/examples/distillation/scripts/extract_for_distil.py new file mode 100644 index 0000000000..27266c82ea --- /dev/null +++ b/examples/distillation/scripts/extract_for_distil.py @@ -0,0 +1,59 @@ +from pytorch_transformers import BertForPreTraining +import torch +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForPreTraining for Transfer Learned Distillation") + parser.add_argument("--bert_model", default='bert-base-uncased', type=str) + parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str) + parser.add_argument("--vocab_transform", action='store_true') + args = parser.parse_args() + + + model = BertForPreTraining.from_pretrained(args.bert_model) + + state_dict = model.state_dict() + compressed_sd = {} + + for w in ['word_embeddings', 'position_embeddings']: + compressed_sd[f'dilbert.embeddings.{w}.weight'] = \ + state_dict[f'bert.embeddings.{w}.weight'] + for w in ['weight', 'bias']: + compressed_sd[f'dilbert.embeddings.LayerNorm.{w}'] = \ + state_dict[f'bert.embeddings.LayerNorm.{w}'] + + std_idx = 0 + for teacher_idx in [0, 2, 4, 7, 9, 11]: + for w in ['weight', 'bias']: + compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}'] + compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}'] + compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}'] + + compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}'] + compressed_sd[f'dilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}'] + + compressed_sd[f'dilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}'] + compressed_sd[f'dilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}'] + compressed_sd[f'dilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \ + state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}'] + std_idx += 1 + + compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight'] + compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias'] + if args.vocab_transform: + for w in ['weight', 'bias']: + compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}'] + compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}'] + + print(f'N layers selected for distillation: {std_idx}') + print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}') + + print(f'Save transfered checkpoint to {args.dump_checkpoint}.') + torch.save(compressed_sd, args.dump_checkpoint)