commplying with isort
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
from .modules import *
|
||||
|
||||
from .configuration_bert_masked import MaskedBertConfig
|
||||
|
||||
from .modeling_bert_masked import (
|
||||
MaskedBertModel,
|
||||
MaskedBertForMultipleChoice,
|
||||
MaskedBertForQuestionAnswering,
|
||||
MaskedBertForSequenceClassification,
|
||||
MaskedBertForTokenClassification,
|
||||
MaskedBertForMultipleChoice,
|
||||
MaskedBertModel,
|
||||
)
|
||||
from .modules import *
|
||||
|
||||
@@ -19,8 +19,9 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
|
||||
|
||||
import logging
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -26,13 +26,16 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedLinear
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from transformers.modeling_bert import (
|
||||
ACT2FN,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BertLayerNorm,
|
||||
load_tf_weights_in_bert,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from transformers.modeling_bert import load_tf_weights_in_bert, ACT2FN, BertLayerNorm
|
||||
|
||||
from emmental import MaskedLinear
|
||||
from emmental import MaskedBertConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
from .masked_nn import MaskedLinear
|
||||
|
||||
@@ -19,14 +19,14 @@ the weight matrix to prune a portion of the weights.
|
||||
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import init
|
||||
|
||||
import math
|
||||
|
||||
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
class MaskedLinear(nn.Linear):
|
||||
|
||||
Reference in New Issue
Block a user