commplying with isort

This commit is contained in:
Victor SANH
2020-05-28 00:26:39 -04:00
parent db2a3b2e01
commit 5c8e5b3709
9 changed files with 29 additions and 28 deletions

View File

@@ -1,11 +1,9 @@
from .modules import *
from .configuration_bert_masked import MaskedBertConfig
from .modeling_bert_masked import (
MaskedBertModel,
MaskedBertForMultipleChoice,
MaskedBertForQuestionAnswering,
MaskedBertForSequenceClassification,
MaskedBertForTokenClassification,
MaskedBertForMultipleChoice,
MaskedBertModel,
)
from .modules import *

View File

@@ -19,8 +19,9 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)

View File

@@ -26,13 +26,16 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from emmental import MaskedBertConfig, MaskedLinear
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_bert import (
ACT2FN,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertLayerNorm,
load_tf_weights_in_bert,
)
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.modeling_bert import load_tf_weights_in_bert, ACT2FN, BertLayerNorm
from emmental import MaskedLinear
from emmental import MaskedBertConfig
logger = logging.getLogger(__name__)

View File

@@ -1,2 +1,2 @@
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
from .masked_nn import MaskedLinear

View File

@@ -19,14 +19,14 @@ the weight matrix to prune a portion of the weights.
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
"""
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
import math
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
class MaskedLinear(nn.Linear):