diff --git a/examples/distillation/scripts/extract_distilbert.py b/examples/distillation/scripts/extract_distilbert.py index d709268cf0..15b48802fb 100644 --- a/examples/distillation/scripts/extract_distilbert.py +++ b/examples/distillation/scripts/extract_distilbert.py @@ -82,8 +82,8 @@ if __name__ == "__main__": compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"] if args.vocab_transform: for w in ["weight", "bias"]: - compressed_sd[f"vocab_transform.{w}"] = state_dict["cls.predictions.transform.dense.{w}"] - compressed_sd[f"vocab_layer_norm.{w}"] = state_dict["cls.predictions.transform.LayerNorm.{w}"] + compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"] + compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"] print(f"N layers selected for distillation: {std_idx}") print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")