diff --git a/examples/distillation/scripts/extract.py b/examples/distillation/scripts/extract.py new file mode 100644 index 0000000000..5ae1607f3f --- /dev/null +++ b/examples/distillation/scripts/extract.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocessing script before training the distilled model. +Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2. +""" +from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel +import torch +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation") + parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"]) + parser.add_argument("--model_name", default='roberta-large', type=str) + parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str) + parser.add_argument("--vocab_transform", action='store_true') + args = parser.parse_args() + + + if args.model_type == 'roberta': + model = RobertaForMaskedLM.from_pretrained(args.model_name) + prefix = 'roberta' + elif args.model_type == 'gpt2': + model = GPT2LMHeadModel.from_pretrained(args.model_name) + prefix = 'transformer' + + state_dict = model.state_dict() + compressed_sd = {} + + ### Embeddings ### + if args.model_type == 'gpt2': + for param_name in ['wte.weight', 'wpe.weight']: + compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}'] + else: + for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']: + param_name = f'{prefix}.embeddings.{w}.weight' + compressed_sd[param_name] = state_dict[param_name] + for w in ['weight', 'bias']: + param_name = f'{prefix}.embeddings.LayerNorm.{w}' + compressed_sd[param_name] = state_dict[param_name] + + ### Transformer Blocks ### + std_idx = 0 + for teacher_idx in [0, 2, 4, 7, 9, 11]: + if args.model_type == 'gpt2': + for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']: + for w in ['weight', 'bias']: + compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \ + state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}'] + compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias'] + else: + for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value', + 'attention.output.dense', 'attention.output.LayerNorm', + 'intermediate.dense', 'output.dense', 'output.LayerNorm']: + for w in ['weight', 'bias']: + compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \ + state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}'] + std_idx += 1 + + ### Language Modeling Head ###s + if args.model_type == 'roberta': + for layer in ['lm_head.decoder.weight', 'lm_head.bias']: + compressed_sd[f'{layer}'] = state_dict[f'{layer}'] + if args.vocab_transform: + for w in ['weight', 'bias']: + compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}'] + compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}'] + elif args.model_type == 'gpt2': + for w in ['weight', 'bias']: + compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}'] + compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight'] + + print(f'N layers selected for distillation: {std_idx}') + print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}') + + print(f'Save transfered checkpoint to {args.dump_checkpoint}.') + torch.save(compressed_sd, args.dump_checkpoint)