Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There's a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
Aymeric Augustin
2019-12-21 15:46:46 +01:00
parent 63e3827c6b
commit fa84ae26d6
200 changed files with 17452 additions and 12594 deletions

View File

@@ -20,63 +20,70 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM
import torch
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default='bert-base-uncased', type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
parser.add_argument("--vocab_transform", action='store_true')
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args()
if args.model_type == 'bert':
if args.model_type == "bert":
model = BertForMaskedLM.from_pretrained(args.model_name)
prefix = 'bert'
prefix = "bert"
else:
raise ValueError(f'args.model_type should be "bert".')
state_dict = model.state_dict()
compressed_sd = {}
for w in ['word_embeddings', 'position_embeddings']:
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \
state_dict[f'{prefix}.embeddings.{w}.weight']
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
state_dict[f'{prefix}.embeddings.LayerNorm.{w}']
for w in ["word_embeddings", "position_embeddings"]:
compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
for w in ["weight", "bias"]:
compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
for w in ['weight', 'bias']:
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}']
for w in ["weight", "bias"]:
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
]
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
]
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}']
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
]
std_idx += 1
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
if args.vocab_transform:
for w in ['weight', 'bias']:
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
for w in ["weight", "bias"]:
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
print(f'N layers selected for distillation: {std_idx}')
print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
print(f"N layers selected for distillation: {std_idx}")
print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
print(f'Save transfered checkpoint to {args.dump_checkpoint}.')
print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint)