BartForCausalLM analogs to ProphetNetForCausalLM (#9128)
* initiliaze bart4causalLM * create BartDecoderWrapper, setters/getters * delete spaces * forward and additional methods * update cache function, loss function, remove ngram* params in data class. * add bartcausallm, bartdecoder testing * correct bart for causal lm * remove at * add mbart as well * up * fix typo * up * correct * add pegasusforcausallm * add blenderbotforcausallm * add blenderbotsmallforcausallm * add marianforcausallm * add test for MarianForCausalLM * add Pegasus test * add BlenderbotSmall test * add blenderbot test * fix a fail * fix an import fail * a fix * fix * Update modeling_pegasus.py * fix models * fix inputs_embeds setting getter * adapt tests * correct repo utils check * finish test improvement * fix tf models as well * make style * make fix-copies * fix copies * run all tests * last changes * fix all tests Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -130,6 +130,12 @@ BartForQuestionAnswering
|
|||||||
.. autoclass:: transformers.BartForQuestionAnswering
|
.. autoclass:: transformers.BartForQuestionAnswering
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
BartForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.BartForCausalLM
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TFBartModel
|
TFBartModel
|
||||||
|
|||||||
@@ -98,6 +98,13 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
BlenderbotForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.BlenderbotForCausalLM
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
TFBlenderbotModel
|
TFBlenderbotModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,13 @@ BlenderbotSmallForConditionalGeneration
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
BlenderbotSmallForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.BlenderbotSmallForCausalLM
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
TFBlenderbotSmallModel
|
TFBlenderbotSmallModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -193,6 +193,13 @@ MarianMTModel
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
MarianForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.MarianForCausalLM
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
TFMarianModel
|
TFMarianModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -124,6 +124,13 @@ MBartForSequenceClassification
|
|||||||
.. autoclass:: transformers.MBartForSequenceClassification
|
.. autoclass:: transformers.MBartForSequenceClassification
|
||||||
|
|
||||||
|
|
||||||
|
MBartForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.MBartForCausalLM
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
TFMBartModel
|
TFMBartModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -131,6 +131,13 @@ PegasusForConditionalGeneration
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
PegasusForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.PegasusForCausalLM
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
TFPegasusModel
|
TFPegasusModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -431,6 +431,7 @@ if is_torch_available():
|
|||||||
_import_structure["models.bart"].extend(
|
_import_structure["models.bart"].extend(
|
||||||
[
|
[
|
||||||
"BART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"BartForCausalLM",
|
||||||
"BartForConditionalGeneration",
|
"BartForConditionalGeneration",
|
||||||
"BartForQuestionAnswering",
|
"BartForQuestionAnswering",
|
||||||
"BartForSequenceClassification",
|
"BartForSequenceClassification",
|
||||||
@@ -468,6 +469,7 @@ if is_torch_available():
|
|||||||
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"BlenderbotForConditionalGeneration",
|
"BlenderbotForConditionalGeneration",
|
||||||
"BlenderbotModel",
|
"BlenderbotModel",
|
||||||
|
"BlenderbotForCausalLM",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.blenderbot_small"].extend(
|
_import_structure["models.blenderbot_small"].extend(
|
||||||
@@ -475,6 +477,7 @@ if is_torch_available():
|
|||||||
"BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"BlenderbotSmallForConditionalGeneration",
|
"BlenderbotSmallForConditionalGeneration",
|
||||||
"BlenderbotSmallModel",
|
"BlenderbotSmallModel",
|
||||||
|
"BlenderbotSmallForCausalLM",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.camembert"].extend(
|
_import_structure["models.camembert"].extend(
|
||||||
@@ -628,9 +631,10 @@ if is_torch_available():
|
|||||||
"LxmertXLayer",
|
"LxmertXLayer",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.marian"].extend(["MarianModel", "MarianMTModel"])
|
_import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"])
|
||||||
_import_structure["models.mbart"].extend(
|
_import_structure["models.mbart"].extend(
|
||||||
[
|
[
|
||||||
|
"MBartForCausalLM",
|
||||||
"MBartForConditionalGeneration",
|
"MBartForConditionalGeneration",
|
||||||
"MBartForQuestionAnswering",
|
"MBartForQuestionAnswering",
|
||||||
"MBartForSequenceClassification",
|
"MBartForSequenceClassification",
|
||||||
@@ -679,7 +683,9 @@ if is_torch_available():
|
|||||||
"load_tf_weights_in_openai_gpt",
|
"load_tf_weights_in_openai_gpt",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.pegasus"].extend(["PegasusForConditionalGeneration", "PegasusModel"])
|
_import_structure["models.pegasus"].extend(
|
||||||
|
["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"]
|
||||||
|
)
|
||||||
_import_structure["models.prophetnet"].extend(
|
_import_structure["models.prophetnet"].extend(
|
||||||
[
|
[
|
||||||
"PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -1517,6 +1523,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.bart import (
|
from .models.bart import (
|
||||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BartForCausalLM,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BartForQuestionAnswering,
|
BartForQuestionAnswering,
|
||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
@@ -1546,11 +1553,13 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.blenderbot import (
|
from .models.blenderbot import (
|
||||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BlenderbotForCausalLM,
|
||||||
BlenderbotForConditionalGeneration,
|
BlenderbotForConditionalGeneration,
|
||||||
BlenderbotModel,
|
BlenderbotModel,
|
||||||
)
|
)
|
||||||
from .models.blenderbot_small import (
|
from .models.blenderbot_small import (
|
||||||
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BlenderbotSmallForCausalLM,
|
||||||
BlenderbotSmallForConditionalGeneration,
|
BlenderbotSmallForConditionalGeneration,
|
||||||
BlenderbotSmallModel,
|
BlenderbotSmallModel,
|
||||||
)
|
)
|
||||||
@@ -1691,8 +1700,9 @@ if TYPE_CHECKING:
|
|||||||
LxmertVisualFeatureEncoder,
|
LxmertVisualFeatureEncoder,
|
||||||
LxmertXLayer,
|
LxmertXLayer,
|
||||||
)
|
)
|
||||||
from .models.marian import MarianModel, MarianMTModel
|
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
|
||||||
from .models.mbart import (
|
from .models.mbart import (
|
||||||
|
MBartForCausalLM,
|
||||||
MBartForConditionalGeneration,
|
MBartForConditionalGeneration,
|
||||||
MBartForQuestionAnswering,
|
MBartForQuestionAnswering,
|
||||||
MBartForSequenceClassification,
|
MBartForSequenceClassification,
|
||||||
@@ -1734,7 +1744,7 @@ if TYPE_CHECKING:
|
|||||||
OpenAIGPTPreTrainedModel,
|
OpenAIGPTPreTrainedModel,
|
||||||
load_tf_weights_in_openai_gpt,
|
load_tf_weights_in_openai_gpt,
|
||||||
)
|
)
|
||||||
from .models.pegasus import PegasusForConditionalGeneration, PegasusModel
|
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
|
||||||
from .models.prophetnet import (
|
from .models.prophetnet import (
|
||||||
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
ProphetNetDecoder,
|
ProphetNetDecoder,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from ..albert.modeling_albert import (
|
|||||||
AlbertModel,
|
AlbertModel,
|
||||||
)
|
)
|
||||||
from ..bart.modeling_bart import (
|
from ..bart.modeling_bart import (
|
||||||
|
BartForCausalLM,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BartForQuestionAnswering,
|
BartForQuestionAnswering,
|
||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
@@ -50,8 +51,12 @@ from ..bert.modeling_bert import (
|
|||||||
BertModel,
|
BertModel,
|
||||||
)
|
)
|
||||||
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
|
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
|
||||||
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel
|
from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel
|
||||||
from ..blenderbot_small.modeling_blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel
|
from ..blenderbot_small.modeling_blenderbot_small import (
|
||||||
|
BlenderbotSmallForCausalLM,
|
||||||
|
BlenderbotSmallForConditionalGeneration,
|
||||||
|
BlenderbotSmallModel,
|
||||||
|
)
|
||||||
from ..camembert.modeling_camembert import (
|
from ..camembert.modeling_camembert import (
|
||||||
CamembertForCausalLM,
|
CamembertForCausalLM,
|
||||||
CamembertForMaskedLM,
|
CamembertForMaskedLM,
|
||||||
@@ -138,8 +143,9 @@ from ..longformer.modeling_longformer import (
|
|||||||
LongformerModel,
|
LongformerModel,
|
||||||
)
|
)
|
||||||
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
|
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
|
||||||
from ..marian.modeling_marian import MarianModel, MarianMTModel
|
from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel
|
||||||
from ..mbart.modeling_mbart import (
|
from ..mbart.modeling_mbart import (
|
||||||
|
MBartForCausalLM,
|
||||||
MBartForConditionalGeneration,
|
MBartForConditionalGeneration,
|
||||||
MBartForQuestionAnswering,
|
MBartForQuestionAnswering,
|
||||||
MBartForSequenceClassification,
|
MBartForSequenceClassification,
|
||||||
@@ -165,7 +171,7 @@ from ..mpnet.modeling_mpnet import (
|
|||||||
)
|
)
|
||||||
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
|
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
|
||||||
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||||
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
|
from ..pegasus.modeling_pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
|
||||||
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
|
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
|
||||||
from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
|
from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
|
||||||
RagModel,
|
RagModel,
|
||||||
@@ -425,6 +431,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
|||||||
(BertGenerationConfig, BertGenerationDecoder),
|
(BertGenerationConfig, BertGenerationDecoder),
|
||||||
(XLMProphetNetConfig, XLMProphetNetForCausalLM),
|
(XLMProphetNetConfig, XLMProphetNetForCausalLM),
|
||||||
(ProphetNetConfig, ProphetNetForCausalLM),
|
(ProphetNetConfig, ProphetNetForCausalLM),
|
||||||
|
(BartConfig, BartForCausalLM),
|
||||||
|
(MBartConfig, MBartForCausalLM),
|
||||||
|
(PegasusConfig, PegasusForCausalLM),
|
||||||
|
(MarianConfig, MarianForCausalLM),
|
||||||
|
(BlenderbotConfig, BlenderbotForCausalLM),
|
||||||
|
(BlenderbotSmallConfig, BlenderbotSmallForCausalLM),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ if is_tokenizers_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_bart"] = [
|
_import_structure["modeling_bart"] = [
|
||||||
"BART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"BartForCausalLM",
|
||||||
"BartForConditionalGeneration",
|
"BartForConditionalGeneration",
|
||||||
"BartForQuestionAnswering",
|
"BartForQuestionAnswering",
|
||||||
"BartForSequenceClassification",
|
"BartForSequenceClassification",
|
||||||
@@ -53,6 +54,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_bart import (
|
from .modeling_bart import (
|
||||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BartForCausalLM,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BartForQuestionAnswering,
|
BartForQuestionAnswering,
|
||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
|
|||||||
@@ -13,8 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch BART model. """
|
""" PyTorch BART model. """
|
||||||
|
import copy
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
@@ -37,6 +36,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
Seq2SeqLMOutput,
|
Seq2SeqLMOutput,
|
||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
Seq2SeqQuestionAnsweringModelOutput,
|
Seq2SeqQuestionAnsweringModelOutput,
|
||||||
@@ -843,6 +843,30 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||||
|
# create causal mask
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
combined_attention_mask = None
|
||||||
|
if input_shape[-1] > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -945,19 +969,9 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
# create causal mask
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
combined_attention_mask = None
|
)
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
@@ -975,7 +989,7 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
@@ -1012,7 +1026,7 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
head_mask[idx] if head_mask is not None else None,
|
head_mask[idx] if head_mask is not None else None,
|
||||||
@@ -1023,7 +1037,7 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
@@ -1039,7 +1053,9 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -1571,3 +1587,208 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
|||||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BartDecoderWrapper(BartPretrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = BartDecoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class BartForCausalLM(BartPretrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.is_decoder = True
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = BartDecoderWrapper(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||||
|
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||||
|
config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import BartTokenizer, BartForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
>>> model = BartForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||||
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_head_mask=encoder_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
# first step, decoder_cached_states are empty
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -915,7 +915,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None:
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
|||||||
"BlenderbotForConditionalGeneration",
|
"BlenderbotForConditionalGeneration",
|
||||||
"BlenderbotModel",
|
"BlenderbotModel",
|
||||||
"BlenderbotPreTrainedModel",
|
"BlenderbotPreTrainedModel",
|
||||||
|
"BlenderbotForCausalLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -46,6 +47,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_blenderbot import (
|
from .modeling_blenderbot import (
|
||||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BlenderbotForCausalLM,
|
||||||
BlenderbotForConditionalGeneration,
|
BlenderbotForConditionalGeneration,
|
||||||
BlenderbotModel,
|
BlenderbotModel,
|
||||||
BlenderbotPreTrainedModel,
|
BlenderbotPreTrainedModel,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
""" PyTorch Blenderbot model. """
|
""" PyTorch Blenderbot model. """
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -37,6 +38,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
Seq2SeqLMOutput,
|
Seq2SeqLMOutput,
|
||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -805,13 +807,38 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||||
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||||
|
# create causal mask
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
combined_attention_mask = None
|
||||||
|
if input_shape[-1] > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
encoder_head_mask=None,
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -838,12 +865,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
|
||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
|
||||||
- 0 indicates the heas is **masked**.
|
|
||||||
|
|
||||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
of the decoder.
|
of the decoder.
|
||||||
@@ -855,6 +876,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
@@ -907,19 +934,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
# create causal mask
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
combined_attention_mask = None
|
)
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
@@ -929,7 +946,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||||
|
|
||||||
# in constrast to Bart, Blenderbot applies layernorm on inputs_embeds
|
|
||||||
hidden_states = inputs_embeds + positions
|
hidden_states = inputs_embeds + positions
|
||||||
|
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -937,7 +953,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
@@ -974,7 +990,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
head_mask[idx] if head_mask is not None else None,
|
head_mask[idx] if head_mask is not None else None,
|
||||||
@@ -985,10 +1001,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=attention_mask,
|
||||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -1001,7 +1017,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
# add final layer norm
|
# add final layer norm
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
@@ -1336,3 +1354,210 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
||||||
)
|
)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot
|
||||||
|
class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = BlenderbotDecoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot
|
||||||
|
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.is_decoder = True
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = BlenderbotDecoderWrapper(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||||
|
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||||
|
config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import BlenderbotTokenizer, BlenderbotForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = BlenderbotTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
>>> model = BlenderbotForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||||
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_head_mask=encoder_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
# first step, decoder_cached_states are empty
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -925,7 +925,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None:
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
|||||||
"BlenderbotSmallForConditionalGeneration",
|
"BlenderbotSmallForConditionalGeneration",
|
||||||
"BlenderbotSmallModel",
|
"BlenderbotSmallModel",
|
||||||
"BlenderbotSmallPreTrainedModel",
|
"BlenderbotSmallPreTrainedModel",
|
||||||
|
"BlenderbotSmallForCausalLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -46,6 +47,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_blenderbot_small import (
|
from .modeling_blenderbot_small import (
|
||||||
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
BlenderbotSmallForCausalLM,
|
||||||
BlenderbotSmallForConditionalGeneration,
|
BlenderbotSmallForConditionalGeneration,
|
||||||
BlenderbotSmallModel,
|
BlenderbotSmallModel,
|
||||||
BlenderbotSmallPreTrainedModel,
|
BlenderbotSmallPreTrainedModel,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
""" PyTorch BlenderbotSmall model. """
|
""" PyTorch BlenderbotSmall model. """
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -35,6 +36,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
Seq2SeqLMOutput,
|
Seq2SeqLMOutput,
|
||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -805,6 +807,31 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||||
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||||
|
# create causal mask
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
combined_attention_mask = None
|
||||||
|
if input_shape[-1] > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -907,19 +934,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
# create causal mask
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
combined_attention_mask = None
|
)
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
@@ -938,7 +955,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
if head_mask is not None:
|
if head_mask is not None:
|
||||||
@@ -974,7 +991,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
head_mask[idx] if head_mask is not None else None,
|
head_mask[idx] if head_mask is not None else None,
|
||||||
@@ -985,7 +1002,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
@@ -1001,7 +1018,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -1310,3 +1329,210 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
||||||
)
|
)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->BlenderbotSmall
|
||||||
|
class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = BlenderbotSmallDecoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall
|
||||||
|
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.is_decoder = True
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = BlenderbotSmallDecoderWrapper(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||||
|
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||||
|
config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import BlenderbotSmallTokenizer, BlenderbotSmallForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = BlenderbotSmallTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
>>> model = BlenderbotSmallForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||||
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_head_mask=encoder_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
# first step, decoder_cached_states are empty
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -925,7 +925,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None:
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ if is_torch_available():
|
|||||||
"MarianModel",
|
"MarianModel",
|
||||||
"MarianMTModel",
|
"MarianMTModel",
|
||||||
"MarianPreTrainedModel",
|
"MarianPreTrainedModel",
|
||||||
|
"MarianForCausalLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_marian import (
|
from .modeling_marian import (
|
||||||
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST,
|
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
MarianForCausalLM,
|
||||||
MarianModel,
|
MarianModel,
|
||||||
MarianMTModel,
|
MarianMTModel,
|
||||||
MarianPreTrainedModel,
|
MarianPreTrainedModel,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
"""PyTorch MarianMTModel model, ported from the Marian C++ repo."""
|
"""PyTorch MarianMTModel model, ported from the Marian C++ repo."""
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -36,6 +37,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
Seq2SeqLMOutput,
|
Seq2SeqLMOutput,
|
||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -809,6 +811,31 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
|
self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||||
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||||
|
# create causal mask
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
combined_attention_mask = None
|
||||||
|
if input_shape[-1] > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -911,19 +938,9 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
# create causal mask
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
combined_attention_mask = None
|
)
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
@@ -940,7 +957,7 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
@@ -977,7 +994,7 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
head_mask[idx] if head_mask is not None else None,
|
head_mask[idx] if head_mask is not None else None,
|
||||||
@@ -988,7 +1005,7 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
@@ -1004,7 +1021,9 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -1321,3 +1340,210 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
||||||
)
|
)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian
|
||||||
|
class MarianDecoderWrapper(MarianPreTrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = MarianDecoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian
|
||||||
|
class MarianForCausalLM(MarianPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.is_decoder = True
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = MarianDecoderWrapper(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||||
|
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||||
|
config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import MarianTokenizer, MarianForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = MarianTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
>>> model = MarianForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||||
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_head_mask=encoder_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
# first step, decoder_cached_states are empty
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -943,7 +943,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None:
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ if is_tokenizers_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_mbart"] = [
|
_import_structure["modeling_mbart"] = [
|
||||||
"MBART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"MBART_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"MBartForCausalLM",
|
||||||
"MBartForConditionalGeneration",
|
"MBartForConditionalGeneration",
|
||||||
"MBartForQuestionAnswering",
|
"MBartForQuestionAnswering",
|
||||||
"MBartForSequenceClassification",
|
"MBartForSequenceClassification",
|
||||||
@@ -62,6 +63,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_mbart import (
|
from .modeling_mbart import (
|
||||||
MBART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
MBART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
MBartForCausalLM,
|
||||||
MBartForConditionalGeneration,
|
MBartForConditionalGeneration,
|
||||||
MBartForQuestionAnswering,
|
MBartForQuestionAnswering,
|
||||||
MBartForSequenceClassification,
|
MBartForSequenceClassification,
|
||||||
|
|||||||
@@ -13,8 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch MBART model. """
|
""" PyTorch MBART model. """
|
||||||
|
import copy
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -36,6 +35,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
Seq2SeqLMOutput,
|
Seq2SeqLMOutput,
|
||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
Seq2SeqQuestionAnsweringModelOutput,
|
Seq2SeqQuestionAnsweringModelOutput,
|
||||||
@@ -852,6 +852,31 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||||
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||||
|
# create causal mask
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
combined_attention_mask = None
|
||||||
|
if input_shape[-1] > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -954,19 +979,9 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
# create causal mask
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
combined_attention_mask = None
|
)
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
@@ -984,7 +999,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
@@ -1021,7 +1036,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
head_mask[idx] if head_mask is not None else None,
|
head_mask[idx] if head_mask is not None else None,
|
||||||
@@ -1032,7 +1047,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
@@ -1048,7 +1063,9 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
|
||||||
@@ -1570,3 +1587,210 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
|||||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->MBart
|
||||||
|
class MBartDecoderWrapper(MBartPreTrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = MBartDecoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart
|
||||||
|
class MBartForCausalLM(MBartPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.is_decoder = True
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = MBartDecoderWrapper(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||||
|
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||||
|
config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import MBartTokenizer, MBartForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
>>> model = MBartForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||||
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_head_mask=encoder_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
# first step, decoder_cached_states are empty
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -938,7 +938,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None:
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ if is_torch_available():
|
|||||||
"PegasusForConditionalGeneration",
|
"PegasusForConditionalGeneration",
|
||||||
"PegasusModel",
|
"PegasusModel",
|
||||||
"PegasusPreTrainedModel",
|
"PegasusPreTrainedModel",
|
||||||
|
"PegasusForCausalLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -60,6 +61,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_pegasus import (
|
from .modeling_pegasus import (
|
||||||
PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
PegasusForCausalLM,
|
||||||
PegasusForConditionalGeneration,
|
PegasusForConditionalGeneration,
|
||||||
PegasusModel,
|
PegasusModel,
|
||||||
PegasusPreTrainedModel,
|
PegasusPreTrainedModel,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch PEGASUS model. """
|
""" PyTorch PEGASUS model. """
|
||||||
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -36,6 +36,7 @@ from ...file_utils import (
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
Seq2SeqLMOutput,
|
Seq2SeqLMOutput,
|
||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -817,6 +818,31 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||||
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||||
|
# create causal mask
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
combined_attention_mask = None
|
||||||
|
if input_shape[-1] > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -919,19 +945,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
# create causal mask
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
combined_attention_mask = None
|
)
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
@@ -948,7 +964,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
@@ -985,7 +1001,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
head_mask[idx] if head_mask is not None else None,
|
head_mask[idx] if head_mask is not None else None,
|
||||||
@@ -996,7 +1012,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
@@ -1012,7 +1028,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
|
||||||
@@ -1325,3 +1343,210 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
||||||
)
|
)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
|
||||||
|
class PegasusDecoderWrapper(PegasusPreTrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = PegasusDecoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Pegasus
|
||||||
|
class PegasusForCausalLM(PegasusPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.is_decoder = True
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = PegasusDecoderWrapper(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||||
|
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||||
|
config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import PegasusTokenizer, PegasusForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = PegasusTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
>>> model = PegasusForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||||
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_head_mask=encoder_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
# first step, decoder_cached_states are empty
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -955,7 +955,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None:
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1330,7 +1330,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
|
>>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
|
||||||
>>> model = ProphetNetDecoder.from_pretrained('patrickvonplaten/prophetnet-large-uncased-standalone', add_cross_attention=False)
|
>>> model = ProphetNetDecoder.from_pretrained('microsoft/prophetnet-large-uncased', add_cross_attention=False)
|
||||||
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
|
|||||||
@@ -431,6 +431,11 @@ class AutoModelWithLMHead:
|
|||||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
BART_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class BartForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class BartForConditionalGeneration:
|
class BartForConditionalGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
@@ -596,6 +601,11 @@ def load_tf_weights_in_bert_generation(*args, **kwargs):
|
|||||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderbotForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class BlenderbotForConditionalGeneration:
|
class BlenderbotForConditionalGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
@@ -617,6 +627,11 @@ class BlenderbotModel:
|
|||||||
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderbotSmallForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class BlenderbotSmallForConditionalGeneration:
|
class BlenderbotSmallForConditionalGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
@@ -1464,6 +1479,11 @@ class LxmertXLayer:
|
|||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MarianForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class MarianModel:
|
class MarianModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
@@ -1482,6 +1502,11 @@ class MarianMTModel:
|
|||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MBartForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class MBartForConditionalGeneration:
|
class MBartForConditionalGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
@@ -1772,6 +1797,11 @@ def load_tf_weights_in_openai_gpt(*args, **kwargs):
|
|||||||
requires_pytorch(load_tf_weights_in_openai_gpt)
|
requires_pytorch(load_tf_weights_in_openai_gpt)
|
||||||
|
|
||||||
|
|
||||||
|
class PegasusForCausalLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class PegasusForConditionalGeneration:
|
class PegasusForConditionalGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|||||||
@@ -1522,6 +1522,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
|||||||
)
|
)
|
||||||
{% else %}
|
{% else %}
|
||||||
import math
|
import math
|
||||||
|
import copy
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -1663,6 +1664,7 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -1730,6 +1732,13 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit akward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
@@ -1778,19 +1787,30 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, attn_weights, _ = self.self_attn(
|
hidden_states, attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -1849,6 +1869,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -1861,6 +1883,10 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
@@ -1876,6 +1902,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -1894,6 +1921,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=encoder_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -1924,7 +1952,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ClassificationHead with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
|
||||||
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
||||||
"""Head for sentence-level classification tasks."""
|
"""Head for sentence-level classification tasks."""
|
||||||
|
|
||||||
@@ -2036,6 +2064,18 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
|||||||
If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and
|
If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and
|
||||||
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||||
information on the default strategy.
|
information on the default strategy.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -2073,6 +2113,35 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
{{cookiecutter.uppercase_modelname}}_STANDALONE_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
||||||
@@ -2113,6 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -2136,6 +2206,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -2182,7 +2257,14 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -2202,9 +2284,15 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -2253,12 +2341,39 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||||
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||||
|
# create causal mask
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
combined_attention_mask = None
|
||||||
|
if input_shape[-1] > 1:
|
||||||
|
combined_attention_mask = _make_causal_mask(
|
||||||
|
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||||
|
combined_attention_mask = (
|
||||||
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -2295,6 +2410,19 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -2340,19 +2468,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
# create causal mask
|
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = None
|
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
if attention_mask is not None and combined_attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
@@ -2370,8 +2486,15 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -2398,18 +2521,22 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -2421,7 +2548,9 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -2486,6 +2615,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2506,6 +2637,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -2525,6 +2657,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -2603,6 +2737,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2649,6 +2785,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -2681,7 +2819,14 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -2693,6 +2838,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2919,4 +3065,210 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
|||||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}DecoderWrapper with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
|
||||||
|
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PretrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ForCausalLM with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
|
||||||
|
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PretrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.is_decoder = True
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = {{cookiecutter.camelcase_modelname}}DecoderWrapper(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.{{cookiecutter.camelcase_modelname}}Tokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||||
|
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||||
|
config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
||||||
|
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', add_cross_attention=False)
|
||||||
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_head_mask=encoder_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
# first step, decoder_cached_states are empty
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
|
|||||||
@@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
|
|||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_generation_utils import GenerationTesterMixin
|
from .test_generation_utils import GenerationTesterMixin
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -847,4 +847,220 @@ class {{cookiecutter.camelcase_modelname}}ModelIntegrationTests(unittest.TestCas
|
|||||||
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||||
)
|
)
|
||||||
assert generated == EXPECTED
|
assert generated == EXPECTED
|
||||||
|
|
||||||
|
|
||||||
|
class {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = {{cookiecutter.camelcase_modelname}}Config(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = {{cookiecutter.camelcase_modelname}}Decoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = {{cookiecutter.camelcase_modelname}}Decoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}ForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = ({{cookiecutter.camelcase_modelname}}ForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.model_tester = {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class={{cookiecutter.camelcase_modelname}}Config)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
|
|||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_generation_utils import GenerationTesterMixin
|
from .test_generation_utils import GenerationTesterMixin
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -36,6 +36,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
BartConfig,
|
BartConfig,
|
||||||
|
BartForCausalLM,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BartForQuestionAnswering,
|
BartForQuestionAnswering,
|
||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
@@ -178,7 +179,7 @@ class BartModelTester:
|
|||||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
# test that outputs are equal for slice
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
model = BartModel(config=config).to(torch_device).eval()
|
model = BartModel(config=config).to(torch_device).eval()
|
||||||
@@ -719,3 +720,242 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||||
)
|
)
|
||||||
assert generated_summaries == EXPECTED
|
assert generated_summaries == EXPECTED
|
||||||
|
|
||||||
|
|
||||||
|
class BartStandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = BartConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
encoder_layers=self.decoder_layers,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = BartDecoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = BartDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
|
||||||
|
"last_hidden_state"
|
||||||
|
]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.model_tester = BartStandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=BartConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch Blenderbot model. """
|
""" Testing suite for the PyTorch Blenderbot model. """
|
||||||
|
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -31,7 +30,11 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration, BlenderbotModel, BlenderbotTokenizer
|
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration, BlenderbotModel, BlenderbotTokenizer
|
||||||
from transformers.models.blenderbot.modeling_blenderbot import BlenderbotDecoder, BlenderbotEncoder
|
from transformers.models.blenderbot.modeling_blenderbot import (
|
||||||
|
BlenderbotDecoder,
|
||||||
|
BlenderbotEncoder,
|
||||||
|
BlenderbotForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def prepare_blenderbot_inputs_dict(
|
def prepare_blenderbot_inputs_dict(
|
||||||
@@ -165,7 +168,7 @@ class BlenderbotModelTester:
|
|||||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
# test that outputs are equal for slice
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
model = BlenderbotModel(config=config).to(torch_device).eval()
|
model = BlenderbotModel(config=config).to(torch_device).eval()
|
||||||
@@ -300,3 +303,222 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
assert "I think it's because we are so worried about what people think of us." == reply.strip()
|
assert "I think it's because we are so worried about what people think of us." == reply.strip()
|
||||||
del model
|
del model
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderbotStandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = BlenderbotConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = BlenderbotDecoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = BlenderbotDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
# past_key_values = model(input_ids, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
|
||||||
|
"last_hidden_state"
|
||||||
|
]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class BlenderbotStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (BlenderbotDecoder, BlenderbotForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (BlenderbotForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.model_tester = BlenderbotStandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch BlenderbotSmall model. """
|
""" Testing suite for the PyTorch BlenderbotSmall model. """
|
||||||
|
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -39,6 +38,7 @@ if is_torch_available():
|
|||||||
from transformers.models.blenderbot_small.modeling_blenderbot_small import (
|
from transformers.models.blenderbot_small.modeling_blenderbot_small import (
|
||||||
BlenderbotSmallDecoder,
|
BlenderbotSmallDecoder,
|
||||||
BlenderbotSmallEncoder,
|
BlenderbotSmallEncoder,
|
||||||
|
BlenderbotSmallForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -173,7 +173,7 @@ class BlenderbotSmallModelTester:
|
|||||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
# test that outputs are equal for slice
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
model = BlenderbotSmallModel(config=config).to(torch_device).eval()
|
model = BlenderbotSmallModel(config=config).to(torch_device).eval()
|
||||||
@@ -317,3 +317,221 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
|||||||
"have you ever been to a sam club? it's a great club in the south.",
|
"have you ever been to a sam club? it's a great club in the south.",
|
||||||
"have you ever heard of sam harris? he's an american singer, songwriter, and actor.",
|
"have you ever heard of sam harris? he's an american singer, songwriter, and actor.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderbotSmallStandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = BlenderbotSmallConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = BlenderbotSmallDecoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = BlenderbotSmallDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
|
||||||
|
"last_hidden_state"
|
||||||
|
]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class BlenderbotSmallStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (BlenderbotSmallDecoder, BlenderbotSmallForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (BlenderbotSmallForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.model_tester = BlenderbotSmallStandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=BlenderbotSmallConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import unittest
|
|||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
|
from .test_modeling_bart import BartStandaloneDecoderModelTester
|
||||||
from .test_modeling_bert import BertModelTester
|
from .test_modeling_bert import BertModelTester
|
||||||
from .test_modeling_bert_generation import BertGenerationEncoderTester
|
from .test_modeling_bert_generation import BertGenerationEncoderTester
|
||||||
from .test_modeling_common import ids_tensor
|
from .test_modeling_common import ids_tensor
|
||||||
@@ -34,6 +35,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
BartForCausalLM,
|
||||||
BertGenerationDecoder,
|
BertGenerationDecoder,
|
||||||
BertGenerationEncoder,
|
BertGenerationEncoder,
|
||||||
BertLMHeadModel,
|
BertLMHeadModel,
|
||||||
@@ -828,3 +830,57 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_encoder_decoder_model_shared_weights(self):
|
def test_encoder_decoder_model_shared_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||||
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
|
encoder_model = BertModel(config)
|
||||||
|
decoder_model = BartForCausalLM(decoder_config)
|
||||||
|
return encoder_model, decoder_model
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
model_tester_encoder = BertModelTester(self, batch_size=13)
|
||||||
|
model_tester_decoder = BartStandaloneDecoderModelTester(
|
||||||
|
self, batch_size=13, d_model=32, max_position_embeddings=512
|
||||||
|
)
|
||||||
|
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||||
|
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = encoder_config_and_inputs
|
||||||
|
(
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = decoder_config_and_inputs
|
||||||
|
|
||||||
|
# make sure that cross attention layers are added
|
||||||
|
decoder_config.add_cross_attention = True
|
||||||
|
# disable cache for now
|
||||||
|
decoder_config.use_cache = False
|
||||||
|
return {
|
||||||
|
"config": config,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"decoder_config": decoder_config,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
"labels": lm_labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pretrained_model(self):
|
||||||
|
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_shared_weights(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch Marian model. """
|
""" Testing suite for the PyTorch Marian model. """
|
||||||
|
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -45,7 +44,12 @@ if is_torch_available():
|
|||||||
convert_hf_name_to_opus_name,
|
convert_hf_name_to_opus_name,
|
||||||
convert_opus_name_to_hf_name,
|
convert_opus_name_to_hf_name,
|
||||||
)
|
)
|
||||||
from transformers.models.marian.modeling_marian import MarianDecoder, MarianEncoder, shift_tokens_right
|
from transformers.models.marian.modeling_marian import (
|
||||||
|
MarianDecoder,
|
||||||
|
MarianEncoder,
|
||||||
|
MarianForCausalLM,
|
||||||
|
shift_tokens_right,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def prepare_marian_inputs_dict(
|
def prepare_marian_inputs_dict(
|
||||||
@@ -182,7 +186,7 @@ class MarianModelTester:
|
|||||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
# test that outputs are equal for slice
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
model = MarianModel(config=config).to(torch_device).eval()
|
model = MarianModel(config=config).to(torch_device).eval()
|
||||||
@@ -546,3 +550,221 @@ class TestConversionUtils(unittest.TestCase):
|
|||||||
"en-de",
|
"en-de",
|
||||||
]
|
]
|
||||||
self.assertListEqual(expected_opus_names, converted_opus_names)
|
self.assertListEqual(expected_opus_names, converted_opus_names)
|
||||||
|
|
||||||
|
|
||||||
|
class MarianStandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = MarianConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = MarianDecoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = MarianDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
|
||||||
|
"last_hidden_state"
|
||||||
|
]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (MarianDecoder, MarianForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (MarianForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.model_tester = MarianStandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=MarianConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ if is_torch_available():
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
MBartConfig,
|
MBartConfig,
|
||||||
|
MBartForCausalLM,
|
||||||
MBartForConditionalGeneration,
|
MBartForConditionalGeneration,
|
||||||
MBartForQuestionAnswering,
|
MBartForQuestionAnswering,
|
||||||
MBartForSequenceClassification,
|
MBartForSequenceClassification,
|
||||||
@@ -174,7 +175,7 @@ class MBartModelTester:
|
|||||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
# test that outputs are equal for slice
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
model = MBartModel(config=config).to(torch_device).eval()
|
model = MBartModel(config=config).to(torch_device).eval()
|
||||||
@@ -431,3 +432,221 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||||||
outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
|
outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||||
)[0]
|
)[0]
|
||||||
self.assertEqual(prediction, "of the best books I ever read!")
|
self.assertEqual(prediction, "of the best books I ever read!")
|
||||||
|
|
||||||
|
|
||||||
|
class MBartStandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = MBartConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = MBartDecoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = MBartDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
|
||||||
|
"last_hidden_state"
|
||||||
|
]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class MBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (MBartDecoder, MBartForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (MBartForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.model_tester = MBartStandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=MBartConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch PEGASUS model. """
|
""" Testing suite for the PyTorch PEGASUS model. """
|
||||||
|
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -32,7 +31,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel
|
from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel
|
||||||
from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder
|
from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder, PegasusForCausalLM
|
||||||
|
|
||||||
|
|
||||||
def prepare_pegasus_inputs_dict(
|
def prepare_pegasus_inputs_dict(
|
||||||
@@ -166,7 +165,7 @@ class PegasusModelTester:
|
|||||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
# test that outputs are equal for slice
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
model = PegasusModel(config=config).to(torch_device).eval()
|
model = PegasusModel(config=config).to(torch_device).eval()
|
||||||
@@ -308,3 +307,221 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||||||
"California's largest electricity provider has begun",
|
"California's largest electricity provider has begun",
|
||||||
"N-Dubz have revealed they were",
|
"N-Dubz have revealed they were",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PegasusStandaloneDecoderModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
vocab_size=99,
|
||||||
|
batch_size=13,
|
||||||
|
d_model=16,
|
||||||
|
decoder_seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
is_decoder=True,
|
||||||
|
use_attention_mask=True,
|
||||||
|
use_cache=False,
|
||||||
|
use_labels=True,
|
||||||
|
decoder_start_token_id=2,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
decoder_layers=4,
|
||||||
|
encoder_attention_heads=4,
|
||||||
|
decoder_attention_heads=4,
|
||||||
|
max_position_embeddings=30,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.decoder_seq_length = decoder_seq_length
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_attention_mask = use_attention_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.hidden_size = d_model
|
||||||
|
self.num_hidden_layers = decoder_layers
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.num_attention_heads = decoder_attention_heads
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
|
|
||||||
|
self.scope = None
|
||||||
|
self.decoder_key_length = decoder_seq_length
|
||||||
|
self.base_model_out_len = 2
|
||||||
|
self.decoder_attention_idx = 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
|
if self.use_attention_mask:
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
lm_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = PegasusConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.d_model,
|
||||||
|
decoder_layers=self.decoder_layers,
|
||||||
|
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||||
|
encoder_attention_heads=self.encoder_attention_heads,
|
||||||
|
decoder_attention_heads=self.decoder_attention_heads,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
config.use_cache = True
|
||||||
|
model = PegasusDecoder(config=config).to(torch_device).eval()
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, use_cache=True)
|
||||||
|
outputs_use_cache_conf = model(input_ids)
|
||||||
|
outputs_no_past = model(input_ids, use_cache=False)
|
||||||
|
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||||
|
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||||
|
|
||||||
|
past_key_values = outputs["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_attention_mask_past(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
):
|
||||||
|
model = PegasusDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
half_seq_length = input_ids.shape[-1] // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||||
|
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
|
||||||
|
"last_hidden_state"
|
||||||
|
]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
lm_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
def setUp(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.model_tester = PegasusStandaloneDecoderModelTester(self, is_training=False)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=PegasusConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_attn_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# decoder cannot keep gradients
|
||||||
|
return
|
||||||
|
|||||||
@@ -32,17 +32,17 @@ IGNORE_NON_TESTED = [
|
|||||||
# models to ignore for not tested
|
# models to ignore for not tested
|
||||||
"LEDEncoder", # Building part of bigger (tested) model.
|
"LEDEncoder", # Building part of bigger (tested) model.
|
||||||
"LEDDecoder", # Building part of bigger (tested) model.
|
"LEDDecoder", # Building part of bigger (tested) model.
|
||||||
"BartDecoder", # Building part of bigger (tested) model.
|
"BartDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"BartEncoder", # Building part of bigger (tested) model.
|
"BartEncoder", # Building part of bigger (tested) model.
|
||||||
"BertLMHeadModel", # Needs to be setup as decoder.
|
"BertLMHeadModel", # Needs to be setup as decoder.
|
||||||
"BlenderbotSmallEncoder", # Building part of bigger (tested) model.
|
"BlenderbotSmallEncoder", # Building part of bigger (tested) model.
|
||||||
"BlenderbotSmallDecoder", # Building part of bigger (tested) model.
|
"BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"BlenderbotEncoder", # Building part of bigger (tested) model.
|
"BlenderbotEncoder", # Building part of bigger (tested) model.
|
||||||
"BlenderbotDecoder", # Building part of bigger (tested) model.
|
"BlenderbotDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"MBartEncoder", # Building part of bigger (tested) model.
|
"MBartEncoder", # Building part of bigger (tested) model.
|
||||||
"MBartDecoder", # Building part of bigger (tested) model.
|
"MBartDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"PegasusEncoder", # Building part of bigger (tested) model.
|
"PegasusEncoder", # Building part of bigger (tested) model.
|
||||||
"PegasusDecoder", # Building part of bigger (tested) model.
|
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"DPREncoder", # Building part of bigger (tested) model.
|
"DPREncoder", # Building part of bigger (tested) model.
|
||||||
"DPRSpanPredictor", # Building part of bigger (tested) model.
|
"DPRSpanPredictor", # Building part of bigger (tested) model.
|
||||||
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
@@ -78,11 +78,14 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"LEDEncoder",
|
"LEDEncoder",
|
||||||
"LEDDecoder",
|
"LEDDecoder",
|
||||||
"BartDecoder",
|
"BartDecoder",
|
||||||
|
"BartDecoderWrapper",
|
||||||
"BartEncoder",
|
"BartEncoder",
|
||||||
"BlenderbotSmallEncoder",
|
"BlenderbotSmallEncoder",
|
||||||
"BlenderbotSmallDecoder",
|
"BlenderbotSmallDecoder",
|
||||||
|
"BlenderbotSmallDecoderWrapper",
|
||||||
"BlenderbotEncoder",
|
"BlenderbotEncoder",
|
||||||
"BlenderbotDecoder",
|
"BlenderbotDecoder",
|
||||||
|
"BlenderbotDecoderWrapper",
|
||||||
"DPRContextEncoder",
|
"DPRContextEncoder",
|
||||||
"DPREncoder",
|
"DPREncoder",
|
||||||
"DPRReader",
|
"DPRReader",
|
||||||
@@ -93,9 +96,11 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"MT5EncoderModel",
|
"MT5EncoderModel",
|
||||||
"MBartEncoder",
|
"MBartEncoder",
|
||||||
"MBartDecoder",
|
"MBartDecoder",
|
||||||
|
"MBartDecoderWrapper",
|
||||||
"OpenAIGPTDoubleHeadsModel",
|
"OpenAIGPTDoubleHeadsModel",
|
||||||
"PegasusEncoder",
|
"PegasusEncoder",
|
||||||
"PegasusDecoder",
|
"PegasusDecoder",
|
||||||
|
"PegasusDecoderWrapper",
|
||||||
"ProphetNetDecoder",
|
"ProphetNetDecoder",
|
||||||
"ProphetNetEncoder",
|
"ProphetNetEncoder",
|
||||||
"ProphetNetDecoderWrapper",
|
"ProphetNetDecoderWrapper",
|
||||||
@@ -205,9 +210,8 @@ def find_tested_models(test_file):
|
|||||||
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
|
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
|
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
|
||||||
# Check with one less parenthesis
|
# Check with one less parenthesis as well
|
||||||
if len(all_models) == 0:
|
all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
|
||||||
all_models = re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
|
|
||||||
if len(all_models) > 0:
|
if len(all_models) > 0:
|
||||||
model_tested = []
|
model_tested = []
|
||||||
for entry in all_models:
|
for entry in all_models:
|
||||||
|
|||||||
Reference in New Issue
Block a user