From 23edebc0797008f0525fd1eef7f1299b513457ad Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Wed, 2 Oct 2019 11:01:33 -0400 Subject: [PATCH] update extract_distilbert --- ...ct_for_distil.py => extract_distilbert.py} | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) rename examples/distillation/scripts/{extract_for_distil.py => extract_distilbert.py} (76%) diff --git a/examples/distillation/scripts/extract_for_distil.py b/examples/distillation/scripts/extract_distilbert.py similarity index 76% rename from examples/distillation/scripts/extract_for_distil.py rename to examples/distillation/scripts/extract_distilbert.py index 2e7e5c73d8..fdb0662ca7 100644 --- a/examples/distillation/scripts/extract_for_distil.py +++ b/examples/distillation/scripts/extract_distilbert.py @@ -14,6 +14,7 @@ # limitations under the License. """ Preprocessing script before training DistilBERT. +Specific to BERT -> DistilBERT. """ from transformers import BertForMaskedLM, RobertaForMaskedLM import torch @@ -21,7 +22,7 @@ import argparse if __name__ == '__main__': parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation") - parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"]) + parser.add_argument("--model_type", default="bert", choices=["bert"]) parser.add_argument("--model_name", default='bert-base-uncased', type=str) parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str) parser.add_argument("--vocab_transform", action='store_true') @@ -31,9 +32,8 @@ if __name__ == '__main__': if args.model_type == 'bert': model = BertForMaskedLM.from_pretrained(args.model_name) prefix = 'bert' - elif args.model_type == 'roberta': - model = RobertaForMaskedLM.from_pretrained(args.model_name) - prefix = 'roberta' + else: + raise ValueError(f'args.model_type should be "bert".') state_dict = model.state_dict() compressed_sd = {} @@ -68,20 +68,12 @@ if __name__ == '__main__': state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}'] std_idx += 1 - if args.model_type == 'bert': - compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight'] - compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias'] - if args.vocab_transform: - for w in ['weight', 'bias']: - compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}'] - compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}'] - elif args.model_type == 'roberta': - compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight'] - compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias'] - if args.vocab_transform: - for w in ['weight', 'bias']: - compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}'] - compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}'] + compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight'] + compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias'] + if args.vocab_transform: + for w in ['weight', 'bias']: + compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}'] + compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}'] print(f'N layers selected for distillation: {std_idx}') print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')