Feed forward chunking others (#6365)

* Feed forward chunking for Distilbert & Albert

* Added ff chunking for many other models

* Change model signature

* Added chunking for XLM

* Cleaned up by removing some variables.

* remove test_chunking flag

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Pradhy729
2020-08-19 05:31:10 -07:00
committed by GitHub
parent fe0b85e77a
commit 2a7402cbd3
13 changed files with 78 additions and 31 deletions

7
tests/test_modeling_bert.py Normal file → Executable file
View File

@@ -26,15 +26,15 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
from transformers import (
BertConfig,
BertModel,
BertLMHeadModel,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertForMultipleChoice,
BertLMHeadModel,
BertModel,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -370,7 +370,6 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_chunking = True
def setUp(self):
self.model_tester = BertModelTester(self)

8
tests/test_modeling_common.py Normal file → Executable file
View File

@@ -25,15 +25,15 @@ from transformers.testing_utils import require_multigpu, require_torch, slow, to
if is_torch_available():
import torch
import numpy as np
import torch
from transformers import (
AdaptiveEmbedding,
PretrainedConfig,
PreTrainedModel,
BertModel,
BertConfig,
BertModel,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -65,7 +65,6 @@ class ModelTesterMixin:
test_resize_embeddings = True
test_head_masking = True
test_missing_keys = True
test_chunking = False
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -552,9 +551,6 @@ class ModelTesterMixin:
def test_feed_forward_chunking(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_chunking:
return
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)

View File

@@ -555,7 +555,6 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {
@@ -616,7 +615,6 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True
def prepare_kwargs(self):
return {