Add ElectraForCausalLM -> Enable Electra encoder-decoder model (#14729)
* Add ElectraForCausalLM and cover some basic tests & need to fix a few tests
* Fix bugs
* make style
* make fix-copies
* Update doc
* Change docstring to markdown format
* Remove redundant update_keys_to_ignore
This commit is contained in:
@@ -83,6 +83,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o
|
|||||||
[[autodoc]] ElectraForPreTraining
|
[[autodoc]] ElectraForPreTraining
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## ElectraForCausalLM
|
||||||
|
|
||||||
|
[[autodoc]] ElectraForCausalLM
|
||||||
|
- forward
|
||||||
|
|
||||||
## ElectraForMaskedLM
|
## ElectraForMaskedLM
|
||||||
|
|
||||||
[[autodoc]] ElectraForMaskedLM
|
[[autodoc]] ElectraForMaskedLM
|
||||||
|
|||||||
@@ -885,6 +885,7 @@ if is_torch_available():
|
|||||||
_import_structure["models.electra"].extend(
|
_import_structure["models.electra"].extend(
|
||||||
[
|
[
|
||||||
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"ElectraForCausalLM",
|
||||||
"ElectraForMaskedLM",
|
"ElectraForMaskedLM",
|
||||||
"ElectraForMultipleChoice",
|
"ElectraForMultipleChoice",
|
||||||
"ElectraForPreTraining",
|
"ElectraForPreTraining",
|
||||||
@@ -2830,6 +2831,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.electra import (
|
from .models.electra import (
|
||||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
ElectraForCausalLM,
|
||||||
ElectraForMaskedLM,
|
ElectraForMaskedLM,
|
||||||
ElectraForMultipleChoice,
|
ElectraForMultipleChoice,
|
||||||
ElectraForPreTraining,
|
ElectraForPreTraining,
|
||||||
|
|||||||
@@ -218,6 +218,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("transfo-xl", "TransfoXLLMHeadModel"),
|
("transfo-xl", "TransfoXLLMHeadModel"),
|
||||||
("xlnet", "XLNetLMHeadModel"),
|
("xlnet", "XLNetLMHeadModel"),
|
||||||
("xlm", "XLMWithLMHeadModel"),
|
("xlm", "XLMWithLMHeadModel"),
|
||||||
|
("electra", "ElectraForCausalLM"),
|
||||||
("ctrl", "CTRLLMHeadModel"),
|
("ctrl", "CTRLLMHeadModel"),
|
||||||
("reformer", "ReformerModelWithLMHead"),
|
("reformer", "ReformerModelWithLMHead"),
|
||||||
("bert-generation", "BertGenerationDecoder"),
|
("bert-generation", "BertGenerationDecoder"),
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ if is_tokenizers_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_electra"] = [
|
_import_structure["modeling_electra"] = [
|
||||||
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"ElectraForCausalLM",
|
||||||
"ElectraForMaskedLM",
|
"ElectraForMaskedLM",
|
||||||
"ElectraForMultipleChoice",
|
"ElectraForMultipleChoice",
|
||||||
"ElectraForPreTraining",
|
"ElectraForPreTraining",
|
||||||
@@ -79,6 +80,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_electra import (
|
from .modeling_electra import (
|
||||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
ElectraForCausalLM,
|
||||||
ElectraForMaskedLM,
|
ElectraForMaskedLM,
|
||||||
ElectraForMultipleChoice,
|
ElectraForMultipleChoice,
|
||||||
ElectraForPreTraining,
|
ElectraForPreTraining,
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithCrossAttentions,
|
BaseModelOutputWithCrossAttentions,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
MultipleChoiceModelOutput,
|
MultipleChoiceModelOutput,
|
||||||
QuestionAnsweringModelOutput,
|
QuestionAnsweringModelOutput,
|
||||||
@@ -846,6 +847,10 @@ class ElectraModel(ElectraPreTrainedModel):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -868,6 +873,9 @@ class ElectraModel(ElectraPreTrainedModel):
|
|||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape, device=device)
|
attention_mask = torch.ones(input_shape, device=device)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
@@ -879,10 +887,26 @@ class ElectraModel(ElectraPreTrainedModel):
|
|||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(self, "embeddings_project"):
|
if hasattr(self, "embeddings_project"):
|
||||||
@@ -892,6 +916,10 @@ class ElectraModel(ElectraPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -969,14 +997,14 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
|
|||||||
|
|
||||||
discriminator_hidden_states = self.electra(
|
discriminator_hidden_states = self.electra(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids,
|
position_ids=position_ids,
|
||||||
head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_output = discriminator_hidden_states[0]
|
sequence_output = discriminator_hidden_states[0]
|
||||||
@@ -1075,14 +1103,14 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
|
|||||||
|
|
||||||
discriminator_hidden_states = self.electra(
|
discriminator_hidden_states = self.electra(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids,
|
position_ids=position_ids,
|
||||||
head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||||
|
|
||||||
@@ -1166,14 +1194,14 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
|
|||||||
|
|
||||||
generator_hidden_states = self.electra(
|
generator_hidden_states = self.electra(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids,
|
position_ids=position_ids,
|
||||||
head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
generator_sequence_output = generator_hidden_states[0]
|
generator_sequence_output = generator_hidden_states[0]
|
||||||
|
|
||||||
@@ -1247,14 +1275,14 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
|
|||||||
|
|
||||||
discriminator_hidden_states = self.electra(
|
discriminator_hidden_states = self.electra(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids,
|
position_ids=position_ids,
|
||||||
head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||||
|
|
||||||
@@ -1481,3 +1509,152 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
|||||||
hidden_states=discriminator_hidden_states.hidden_states,
|
hidden_states=discriminator_hidden_states.hidden_states,
|
||||||
attentions=discriminator_hidden_states.attentions,
|
attentions=discriminator_hidden_states.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""ELECTRA Model with a `language modeling` head on top for CLM fine-tuning. """, ELECTRA_START_DOCSTRING
|
||||||
|
)
|
||||||
|
class ElectraForCausalLM(ElectraPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
if not config.is_decoder:
|
||||||
|
logger.warning("If you want to use `ElectraLMHeadModel` as a standalone, add `is_decoder=True.`")
|
||||||
|
|
||||||
|
self.electra = ElectraModel(config)
|
||||||
|
self.generator_predictions = ElectraGeneratorPredictions(config)
|
||||||
|
self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.generator_lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.generator_lm_head = new_embeddings
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
labels=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
|
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
|
||||||
|
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see `past_key_values`).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import ElectraTokenizer, ElectraForCausalLM, ElectraConfig
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-generator")
|
||||||
|
>>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
|
||||||
|
>>> config.is_decoder = True
|
||||||
|
>>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)
|
||||||
|
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> prediction_logits = outputs.logits
|
||||||
|
```"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
outputs = self.electra(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))
|
||||||
|
|
||||||
|
lm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||||
|
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (prediction_scores,) + outputs[1:]
|
||||||
|
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=lm_loss,
|
||||||
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.prepare_inputs_for_generation
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||||
|
input_shape = input_ids.shape
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||||
|
|
||||||
|
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -1924,6 +1924,18 @@ class DPRReader:
|
|||||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class ElectraForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class ElectraForMaskedLM:
|
class ElectraForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -29,6 +29,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_PRETRAINING_MAPPING,
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
ElectraForCausalLM,
|
||||||
ElectraForMaskedLM,
|
ElectraForMaskedLM,
|
||||||
ElectraForMultipleChoice,
|
ElectraForMultipleChoice,
|
||||||
ElectraForPreTraining,
|
ElectraForPreTraining,
|
||||||
@@ -117,6 +118,34 @@ class ElectraModelTester:
|
|||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
_,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_electra_model(
|
def create_and_check_electra_model(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -136,6 +165,38 @@ class ElectraModelTester:
|
|||||||
result = model(input_ids)
|
result = model(input_ids)
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_electra_model_as_decoder(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = ElectraModel(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
result = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
def create_and_check_electra_for_masked_lm(
|
def create_and_check_electra_for_masked_lm(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -153,6 +214,24 @@ class ElectraModelTester:
|
|||||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_electra_for_causal_lm(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
model = ElectraForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
def create_and_check_electra_for_token_classification(
|
def create_and_check_electra_for_token_classification(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -281,6 +360,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
ElectraModel,
|
ElectraModel,
|
||||||
ElectraForPreTraining,
|
ElectraForPreTraining,
|
||||||
ElectraForMaskedLM,
|
ElectraForMaskedLM,
|
||||||
|
ElectraForCausalLM,
|
||||||
ElectraForMultipleChoice,
|
ElectraForMultipleChoice,
|
||||||
ElectraForTokenClassification,
|
ElectraForTokenClassification,
|
||||||
ElectraForSequenceClassification,
|
ElectraForSequenceClassification,
|
||||||
@@ -289,6 +369,8 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else ()
|
||||||
|
|
||||||
fx_ready_model_classes = all_model_classes
|
fx_ready_model_classes = all_model_classes
|
||||||
fx_dynamic_ready_model_classes = all_model_classes
|
fx_dynamic_ready_model_classes = all_model_classes
|
||||||
|
|
||||||
@@ -314,6 +396,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_electra_model(*config_and_inputs)
|
self.model_tester.create_and_check_electra_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_electra_model_as_decoder(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_electra_model_as_decoder(*config_and_inputs)
|
||||||
|
|
||||||
def test_electra_model_various_embeddings(self):
|
def test_electra_model_various_embeddings(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||||
@@ -350,6 +436,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = ElectraModel.from_pretrained(model_name)
|
model = ElectraModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_for_causal_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_electra_for_causal_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ElectraModelIntegrationTest(unittest.TestCase):
|
class ElectraModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user