Feed forward chunking others (#6365)
* Feed forward chunking for Distilbert & Albert
* Added ff chunking for many other models
* Change model signature
* Added chunking for XLM
* Cleaned up by removing some variables.
* Remove test_chunking flag

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
0
src/transformers/configuration_reformer.py
Normal file → Executable file
0
src/transformers/configuration_reformer.py
Normal file → Executable file
@@ -191,6 +191,7 @@ class PretrainedConfig(object):
|
|||||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||||
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
||||||
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
||||||
|
# Feed-forward chunk size, read from the "chunk_size_feed_forward" kwarg and
# defaulting to 0 when absent. The key must match the attribute name that the
# model layers read (config.chunk_size_feed_forward); the previous key was
# missing its final "d", so the kwarg was never picked up here.
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
||||||
|
|
||||||
# task specific arguments
|
# task specific arguments
|
||||||
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
||||||
|
|||||||
19
src/transformers/modeling_albert.py
Normal file → Executable file
19
src/transformers/modeling_albert.py
Normal file → Executable file
@@ -43,7 +43,7 @@ from .modeling_outputs import (
|
|||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
|
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -69,6 +69,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
""" Load tf checkpoints in a pytorch model."""
|
""" Load tf checkpoints in a pytorch model."""
|
||||||
try:
|
try:
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -286,6 +287,8 @@ class AlbertLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.attention = AlbertAttention(config)
|
self.attention = AlbertAttention(config)
|
||||||
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
|
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
@@ -297,14 +300,20 @@ class AlbertLayer(nn.Module):
|
|||||||
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
|
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
|
||||||
):
|
):
|
||||||
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
|
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
|
||||||
ffn_output = self.ffn(attention_output[0])
|
|
||||||
ffn_output = self.activation(ffn_output)
|
ffn_output = apply_chunking_to_forward(
|
||||||
ffn_output = self.ffn_output(ffn_output)
|
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
|
||||||
ffn_output = self.dropout(ffn_output)
|
)
|
||||||
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
|
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
|
||||||
|
|
||||||
return (hidden_states,) + attention_output[1:] # add attentions if we output them
|
return (hidden_states,) + attention_output[1:] # add attentions if we output them
|
||||||
|
|
||||||
|
def ff_chunk(self, attention_output):
    """ Apply the feed-forward sub-layer to one chunk of the attention output.

    Called through ``apply_chunking_to_forward`` from ``forward``, so
    ``attention_output`` may be either the full tensor or a slice of it.

    Args:
        attention_output: output of the attention block (or a chunk of it).

    Returns:
        The result of ``ffn`` -> ``activation`` -> ``ffn_output`` -> ``dropout``;
        the caller adds the residual and applies the layer norm.
    """
    ffn_output = self.ffn(attention_output)
    ffn_output = self.activation(ffn_output)
    ffn_output = self.ffn_output(ffn_output)
    # The pre-chunking forward applied dropout right after the output
    # projection; keep that step so the refactor does not change the layer.
    # NOTE(review): relies on self.dropout still being set in __init__ — confirm.
    ffn_output = self.dropout(ffn_output)
    return ffn_output
|
||||||
|
|
||||||
|
|
||||||
class AlbertLayerGroup(nn.Module):
|
class AlbertLayerGroup(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|||||||
@@ -424,7 +424,7 @@ class BertLayer(nn.Module):
|
|||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
12
src/transformers/modeling_distilbert.py
Normal file → Executable file
12
src/transformers/modeling_distilbert.py
Normal file → Executable file
@@ -44,7 +44,12 @@ from .modeling_outputs import (
|
|||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from .modeling_utils import (
|
||||||
|
PreTrainedModel,
|
||||||
|
apply_chunking_to_forward,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_linear_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -208,6 +213,8 @@ class FFN(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dropout = nn.Dropout(p=config.dropout)
|
self.dropout = nn.Dropout(p=config.dropout)
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
|
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
|
||||||
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
|
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
|
||||||
assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
|
assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
|
||||||
@@ -216,6 +223,9 @@ class FFN(nn.Module):
|
|||||||
self.activation = gelu if config.activation == "gelu" else nn.ReLU()
|
self.activation = gelu if config.activation == "gelu" else nn.ReLU()
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
|
||||||
|
|
||||||
|
def ff_chunk(self, input):
|
||||||
x = self.lin1(input)
|
x = self.lin1(input)
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
x = self.lin2(x)
|
x = self.lin2(x)
|
||||||
|
|||||||
19
src/transformers/modeling_longformer.py
Normal file → Executable file
19
src/transformers/modeling_longformer.py
Normal file → Executable file
@@ -41,7 +41,12 @@ from .modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
|
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
|
||||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from .modeling_utils import (
|
||||||
|
PreTrainedModel,
|
||||||
|
apply_chunking_to_forward,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_linear_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -685,6 +690,8 @@ class LongformerLayer(nn.Module):
|
|||||||
self.attention = LongformerAttention(config, layer_id)
|
self.attention = LongformerAttention(config, layer_id)
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states, attention_mask=None, output_attentions=False,
|
self, hidden_states, attention_mask=None, output_attentions=False,
|
||||||
@@ -693,11 +700,17 @@ class LongformerLayer(nn.Module):
|
|||||||
attn_output = self_attn_outputs[0]
|
attn_output = self_attn_outputs[0]
|
||||||
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
intermediate_output = self.intermediate(attn_output)
|
layer_output = apply_chunking_to_forward(
|
||||||
layer_output = self.output(intermediate_output, attn_output)
|
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output
|
||||||
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def ff_chunk(self, attn_output):
    """ Run the intermediate and output sub-layers on ``attn_output``.

    Used as the ``forward_fn`` handed to ``apply_chunking_to_forward``, so
    ``attn_output`` may be a chunk of the attention output.

    Args:
        attn_output: attention output (or a chunk of it).

    Returns:
        Whatever ``self.output`` returns when given the intermediate
        activations together with the same ``attn_output``.
    """
    # self.output takes both the intermediate result and the untouched input
    # (presumably for the residual connection — TODO confirm in BertOutput).
    return self.output(self.intermediate(attn_output), attn_output)
|
||||||
|
|
||||||
|
|
||||||
class LongformerEncoder(nn.Module):
|
class LongformerEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|||||||
4
src/transformers/modeling_reformer.py
Normal file → Executable file
4
src/transformers/modeling_reformer.py
Normal file → Executable file
@@ -1369,7 +1369,7 @@ class ChunkReformerFeedForward(nn.Module):
|
|||||||
|
|
||||||
def forward(self, attention_output):
|
def forward(self, attention_output):
|
||||||
return apply_chunking_to_forward(
|
return apply_chunking_to_forward(
|
||||||
self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output,
|
self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_chunk(self, hidden_states):
|
def forward_chunk(self, hidden_states):
|
||||||
@@ -1730,7 +1730,7 @@ class ReformerOnlyLMHead(nn.Module):
|
|||||||
self.decoder.bias = self.bias
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
|
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
|
||||||
|
|
||||||
def forward_chunk(self, hidden_states):
|
def forward_chunk(self, hidden_states):
|
||||||
hidden_states = self.decoder(hidden_states)
|
hidden_states = self.decoder(hidden_states)
|
||||||
|
|||||||
8
src/transformers/modeling_utils.py
Normal file → Executable file
8
src/transformers/modeling_utils.py
Normal file → Executable file
@@ -1519,7 +1519,7 @@ def prune_layer(
|
|||||||
|
|
||||||
|
|
||||||
def apply_chunking_to_forward(
|
def apply_chunking_to_forward(
|
||||||
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
|
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
|
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
|
||||||
@@ -1529,12 +1529,12 @@ def apply_chunking_to_forward(
|
|||||||
directly applying :obj:`forward_fn` to :obj:`input_tensors`.
|
directly applying :obj:`forward_fn` to :obj:`input_tensors`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
forward_fn (:obj:`Callable[..., torch.Tensor]`):
|
||||||
|
The forward function of the model.
|
||||||
chunk_size (:obj:`int`):
|
chunk_size (:obj:`int`):
|
||||||
The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
|
The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
|
||||||
chunk_dim (:obj:`int`):
|
chunk_dim (:obj:`int`):
|
||||||
The dimension over which the :obj:`input_tensors` should be chunked.
|
The dimension over which the :obj:`input_tensors` should be chunked.
|
||||||
forward_fn (:obj:`Callable[..., torch.Tensor]`):
|
|
||||||
The forward function of the model.
|
|
||||||
input_tensors (:obj:`Tuple[torch.Tensor]`):
|
input_tensors (:obj:`Tuple[torch.Tensor]`):
|
||||||
The input tensors of ``forward_fn`` which will be chunked.
|
The input tensors of ``forward_fn`` which will be chunked.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1550,7 +1550,7 @@ def apply_chunking_to_forward(
|
|||||||
|
|
||||||
# implement a chunked forward function
|
# implement a chunked forward function
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
|
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
|
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
|
||||||
|
|||||||
6
src/transformers/modeling_xlm.py
Normal file → Executable file
6
src/transformers/modeling_xlm.py
Normal file → Executable file
@@ -50,6 +50,7 @@ from .modeling_utils import (
|
|||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
SequenceSummary,
|
SequenceSummary,
|
||||||
SQuADHead,
|
SQuADHead,
|
||||||
|
apply_chunking_to_forward,
|
||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
@@ -212,8 +213,13 @@ class TransformerFFN(nn.Module):
|
|||||||
self.lin1 = nn.Linear(in_dim, dim_hidden)
|
self.lin1 = nn.Linear(in_dim, dim_hidden)
|
||||||
self.lin2 = nn.Linear(dim_hidden, out_dim)
|
self.lin2 = nn.Linear(dim_hidden, out_dim)
|
||||||
self.act = gelu if config.gelu_activation else F.relu
|
self.act = gelu if config.gelu_activation else F.relu
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
|
||||||
|
|
||||||
|
def ff_chunk(self, input):
|
||||||
x = self.lin1(input)
|
x = self.lin1(input)
|
||||||
x = self.act(x)
|
x = self.act(x)
|
||||||
x = self.lin2(x)
|
x = self.lin2(x)
|
||||||
|
|||||||
21
src/transformers/modeling_xlnet.py
Normal file → Executable file
21
src/transformers/modeling_xlnet.py
Normal file → Executable file
@@ -35,7 +35,14 @@ from .file_utils import (
|
|||||||
add_start_docstrings_to_callable,
|
add_start_docstrings_to_callable,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
|
from .modeling_utils import (
|
||||||
|
PoolerAnswerClass,
|
||||||
|
PoolerEndLogits,
|
||||||
|
PoolerStartLogits,
|
||||||
|
PreTrainedModel,
|
||||||
|
SequenceSummary,
|
||||||
|
apply_chunking_to_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -495,6 +502,8 @@ class XLNetLayer(nn.Module):
|
|||||||
self.rel_attn = XLNetRelativeAttention(config)
|
self.rel_attn = XLNetRelativeAttention(config)
|
||||||
self.ff = XLNetFeedForward(config)
|
self.ff = XLNetFeedForward(config)
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -524,12 +533,18 @@ class XLNetLayer(nn.Module):
|
|||||||
output_h, output_g = outputs[:2]
|
output_h, output_g = outputs[:2]
|
||||||
|
|
||||||
if output_g is not None:
|
if output_g is not None:
|
||||||
output_g = self.ff(output_g)
|
output_g = apply_chunking_to_forward(
|
||||||
output_h = self.ff(output_h)
|
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_g
|
||||||
|
)
|
||||||
|
output_h = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_h)
|
||||||
|
|
||||||
outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there
|
outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def ff_chunk(self, output_x):
    """ Feed ``output_x`` through ``self.ff`` and return the result.

    Thin adapter so the feed-forward module can be passed as ``forward_fn``
    to ``apply_chunking_to_forward`` for both the ``output_h`` and
    ``output_g`` streams.
    """
    return self.ff(output_x)
|
||||||
|
|
||||||
|
|
||||||
class XLNetPreTrainedModel(PreTrainedModel):
|
class XLNetPreTrainedModel(PreTrainedModel):
|
||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
|
|||||||
7
tests/test_modeling_bert.py
Normal file → Executable file
7
tests/test_modeling_bert.py
Normal file → Executable file
@@ -26,15 +26,15 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import (
|
from transformers import (
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertModel,
|
|
||||||
BertLMHeadModel,
|
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
|
BertForMultipleChoice,
|
||||||
BertForNextSentencePrediction,
|
BertForNextSentencePrediction,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
BertForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
BertForTokenClassification,
|
BertForTokenClassification,
|
||||||
BertForMultipleChoice,
|
BertLMHeadModel,
|
||||||
|
BertModel,
|
||||||
)
|
)
|
||||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
@@ -370,7 +370,6 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_chunking = True
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = BertModelTester(self)
|
self.model_tester = BertModelTester(self)
|
||||||
|
|||||||
8
tests/test_modeling_common.py
Normal file → Executable file
8
tests/test_modeling_common.py
Normal file → Executable file
@@ -25,15 +25,15 @@ from transformers.testing_utils import require_multigpu, require_torch, slow, to
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AdaptiveEmbedding,
|
AdaptiveEmbedding,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
BertModel,
|
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
BertModel,
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
@@ -65,7 +65,6 @@ class ModelTesterMixin:
|
|||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
test_missing_keys = True
|
test_missing_keys = True
|
||||||
test_chunking = False
|
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
@@ -552,9 +551,6 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if not self.test_chunking:
|
|
||||||
return
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
config = copy.deepcopy(original_config)
|
config = copy.deepcopy(original_config)
|
||||||
|
|||||||
@@ -555,7 +555,6 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_chunking = True
|
|
||||||
|
|
||||||
def prepare_kwargs(self):
|
def prepare_kwargs(self):
|
||||||
return {
|
return {
|
||||||
@@ -616,7 +615,6 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_chunking = True
|
|
||||||
|
|
||||||
def prepare_kwargs(self):
|
def prepare_kwargs(self):
|
||||||
return {
|
return {
|
||||||
|
|||||||
Reference in New Issue
Block a user