commplying with isort

This commit is contained in:
Victor SANH
2020-05-28 00:26:39 -04:00
parent db2a3b2e01
commit 5c8e5b3709
9 changed files with 29 additions and 28 deletions

View File

@@ -17,13 +17,13 @@ For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceC
as a standard :class:`~transformers.BertForSequenceClassification`.
"""
import argparse
import os
import shutil
import argparse
import torch
from emmental.modules import MagnitudeBinarizer, TopKBinarizer, ThresholdBinarizer
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
def main(args):
@@ -40,13 +40,13 @@ def main(args):
for name, tensor in model.items():
if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
print(f"Copied layer {name}")
elif "classifier" in name or "qa_output" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
print(f"Copied layer {name}")
elif "bias" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
print(f"Copied layer {name}")
else:
if pruning_method == "magnitude":
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)