Add TFBartForSequenceClassification (#20570)
* read to load * base functionality * revert init * fix dummy data * moving right along * moving right along * finally * cleanup * pull out comment * add test * update docstring for main class * flake comments and rewriting copies from make repo-consistency` * remove irrelevant differences/accidental spaces * put copies back after space removals * mid * final test pass * stray comment * update test file * update test file * fixup * black * missed * black missed one more * sytle * add doc update * fix order of output class * comment * Revert "comment" This reverts commit 03f86b6948808461939cc8ad4ad74305dfb67700. * remove redundant function, and redundant reshape * move change out of common * style * put common spaces back * reorder kwargs in output * doc style
This commit is contained in:
@@ -157,6 +157,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
[[autodoc]] TFBartForConditionalGeneration
|
[[autodoc]] TFBartForConditionalGeneration
|
||||||
- call
|
- call
|
||||||
|
|
||||||
|
## TFBartForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] TFBartForSequenceClassification
|
||||||
|
- call
|
||||||
|
|
||||||
## FlaxBartModel
|
## FlaxBartModel
|
||||||
|
|
||||||
[[autodoc]] FlaxBartModel
|
[[autodoc]] FlaxBartModel
|
||||||
|
|||||||
@@ -2513,7 +2513,9 @@ else:
|
|||||||
"TFAutoModelWithLMHead",
|
"TFAutoModelWithLMHead",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.bart"].extend(["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"])
|
_import_structure["models.bart"].extend(
|
||||||
|
["TFBartForConditionalGeneration", "TFBartForSequenceClassification", "TFBartModel", "TFBartPretrainedModel"]
|
||||||
|
)
|
||||||
_import_structure["models.bert"].extend(
|
_import_structure["models.bert"].extend(
|
||||||
[
|
[
|
||||||
"TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -5402,7 +5404,12 @@ if TYPE_CHECKING:
|
|||||||
TFAutoModelForVision2Seq,
|
TFAutoModelForVision2Seq,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
|
from .models.bart import (
|
||||||
|
TFBartForConditionalGeneration,
|
||||||
|
TFBartForSequenceClassification,
|
||||||
|
TFBartModel,
|
||||||
|
TFBartPretrainedModel,
|
||||||
|
)
|
||||||
from .models.bert import (
|
from .models.bert import (
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFBertEmbeddings,
|
TFBertEmbeddings,
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ from . import (
|
|||||||
T5Config,
|
T5Config,
|
||||||
TFAlbertForPreTraining,
|
TFAlbertForPreTraining,
|
||||||
TFBartForConditionalGeneration,
|
TFBartForConditionalGeneration,
|
||||||
|
TFBartForSequenceClassification,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
TFBertForQuestionAnswering,
|
TFBertForQuestionAnswering,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
@@ -136,6 +137,7 @@ MODEL_CLASSES = {
|
|||||||
"bart": (
|
"bart": (
|
||||||
BartConfig,
|
BartConfig,
|
||||||
TFBartForConditionalGeneration,
|
TFBartForConditionalGeneration,
|
||||||
|
TFBartForSequenceClassification,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -623,6 +623,9 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
|
cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
|
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`
|
||||||
encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
@@ -643,6 +646,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
cross_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||||
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|||||||
@@ -1190,7 +1190,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
|
|
||||||
return self.serving_output(output)
|
return self.serving_output(output)
|
||||||
|
|
||||||
def serving_output(output):
|
def serving_output(self, output):
|
||||||
"""
|
"""
|
||||||
Prepare the output of the saved model. Each model must implement this function.
|
Prepare the output of the saved model. Each model must implement this function.
|
||||||
|
|
||||||
|
|||||||
@@ -268,6 +268,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
[
|
[
|
||||||
# Model for Sequence Classification mapping
|
# Model for Sequence Classification mapping
|
||||||
("albert", "TFAlbertForSequenceClassification"),
|
("albert", "TFAlbertForSequenceClassification"),
|
||||||
|
("bart", "TFBartForSequenceClassification"),
|
||||||
("bert", "TFBertForSequenceClassification"),
|
("bert", "TFBertForSequenceClassification"),
|
||||||
("camembert", "TFCamembertForSequenceClassification"),
|
("camembert", "TFCamembertForSequenceClassification"),
|
||||||
("convbert", "TFConvBertForSequenceClassification"),
|
("convbert", "TFConvBertForSequenceClassification"),
|
||||||
|
|||||||
@@ -63,7 +63,12 @@ try:
|
|||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
_import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]
|
_import_structure["modeling_tf_bart"] = [
|
||||||
|
"TFBartForConditionalGeneration",
|
||||||
|
"TFBartForSequenceClassification",
|
||||||
|
"TFBartModel",
|
||||||
|
"TFBartPretrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_flax_available():
|
if not is_flax_available():
|
||||||
@@ -116,7 +121,12 @@ if TYPE_CHECKING:
|
|||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
|
from .modeling_tf_bart import (
|
||||||
|
TFBartForConditionalGeneration,
|
||||||
|
TFBartForSequenceClassification,
|
||||||
|
TFBartModel,
|
||||||
|
TFBartPretrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_flax_available():
|
if not is_flax_available():
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from ...modeling_tf_outputs import (
|
|||||||
TFBaseModelOutputWithPastAndCrossAttentions,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
|
TFSeq2SeqSequenceClassifierOutput,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Public API
|
# Public API
|
||||||
@@ -35,6 +36,7 @@ from ...modeling_tf_utils import (
|
|||||||
TFCausalLanguageModelingLoss,
|
TFCausalLanguageModelingLoss,
|
||||||
TFModelInputType,
|
TFModelInputType,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
|
TFSequenceClassificationLoss,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
unpack_inputs,
|
unpack_inputs,
|
||||||
)
|
)
|
||||||
@@ -460,6 +462,24 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFBartClassificationHead(tf.keras.layers.Layer):
|
||||||
|
"""Head for sentence-level classification tasks."""
|
||||||
|
|
||||||
|
def __init__(self, inner_dim: int, num_classes: int, pooler_dropout: float, name: str, **kwargs):
|
||||||
|
super().__init__(name=name, **kwargs)
|
||||||
|
self.dense = tf.keras.layers.Dense(inner_dim, name="dense")
|
||||||
|
self.dropout = tf.keras.layers.Dropout(pooler_dropout)
|
||||||
|
self.out_proj = tf.keras.layers.Dense(num_classes, name="out_proj")
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
hidden_states = self.dropout(inputs)
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = tf.keras.activations.tanh(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.out_proj(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class TFBartPretrainedModel(TFPreTrainedModel):
|
class TFBartPretrainedModel(TFPreTrainedModel):
|
||||||
config_class = BartConfig
|
config_class = BartConfig
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
@@ -726,7 +746,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
@@ -1465,3 +1484,141 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
||||||
)
|
)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
Bart model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
|
||||||
|
tasks.
|
||||||
|
""",
|
||||||
|
BART_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss):
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
pad_token = self.config.pad_token_id
|
||||||
|
input_ids = tf.constant([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
|
||||||
|
dummy_inputs = {
|
||||||
|
"attention_mask": tf.cast(tf.math.not_equal(input_ids, (pad_token)), dtype=tf.int32),
|
||||||
|
"input_ids": input_ids,
|
||||||
|
}
|
||||||
|
return dummy_inputs
|
||||||
|
|
||||||
|
def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
|
||||||
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
|
||||||
|
self.classification_head = TFBartClassificationHead(
|
||||||
|
config.d_model, config.num_labels, config.classifier_dropout, name="classification_head"
|
||||||
|
)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=TFSeq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
@unpack_inputs
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
labels: Optional[tf.Tensor] = None,
|
||||||
|
training: Optional[bool] = False,
|
||||||
|
) -> Union[TFSeq2SeqSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
||||||
|
r"""
|
||||||
|
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if input_ids is None and inputs_embeds is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
decoder_position_ids=decoder_position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
|
encoder_outputs=encoder_outputs,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden_state = outputs[0]
|
||||||
|
eos_mask = tf.equal(input_ids, self.config.eos_token_id)
|
||||||
|
# Drop the False entries (the non-<eos> positions) from each row. Then verify all the final
|
||||||
|
# entries are True
|
||||||
|
self_masked = tf.reshape(tf.boolean_mask(eos_mask, eos_mask), (tf.shape(input_ids)[0], -1))
|
||||||
|
tf.Assert(tf.reduce_all(self_masked[:, -1]), ["All examples must have the same number of <eos> tokens."])
|
||||||
|
|
||||||
|
masked = tf.reshape(
|
||||||
|
tf.boolean_mask(last_hidden_state, eos_mask),
|
||||||
|
(tf.shape(input_ids)[0], tf.shape(self_masked)[1], tf.shape(last_hidden_state)[-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
sentence_representation = masked[:, -1, :]
|
||||||
|
logits = self.classification_head(sentence_representation)
|
||||||
|
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return TFSeq2SeqSequenceClassifierOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||||
|
decoder_attentions=outputs.decoder_attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||||
|
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||||
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def serving_output(self, output):
|
||||||
|
logits = tf.convert_to_tensor(output.logits)
|
||||||
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
|
return TFSeq2SeqSequenceClassifierOutput(
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=pkv,
|
||||||
|
decoder_hidden_states=dec_hs,
|
||||||
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
|
encoder_hidden_states=enc_hs,
|
||||||
|
encoder_attentions=enc_attns,
|
||||||
|
)
|
||||||
|
|||||||
@@ -449,6 +449,13 @@ class TFBartForConditionalGeneration(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFBartForSequenceClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFBartModel(metaclass=DummyObject):
|
class TFBartModel(metaclass=DummyObject):
|
||||||
_backends = ["tf"]
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -29,7 +31,7 @@ from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFBartForConditionalGeneration, TFBartModel
|
from transformers import TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@@ -76,7 +78,13 @@ class TFBartModelTester:
|
|||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
|
# Ids are clipped to avoid "beginning of sequence", "end of sequence", and "pad" tokens
|
||||||
|
input_ids = tf.clip_by_value(
|
||||||
|
ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size),
|
||||||
|
clip_value_min=self.eos_token_id + 1,
|
||||||
|
clip_value_max=self.vocab_size + 1,
|
||||||
|
)
|
||||||
|
# Explicitly add "end of sequence" to the inputs
|
||||||
eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
|
eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
|
||||||
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
|
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
|
||||||
|
|
||||||
@@ -181,7 +189,9 @@ def prepare_bart_inputs_dict(
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
|
class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
|
all_model_classes = (
|
||||||
|
(TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel) if is_tf_available() else ()
|
||||||
|
)
|
||||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -228,6 +238,119 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
|||||||
def test_onnx_compliancy(self):
|
def test_onnx_compliancy(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# TFBartForSequenceClassification does not support inputs_embeds
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in (TFBartForConditionalGeneration, TFBartModel):
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
inputs = copy.deepcopy(inputs_dict)
|
||||||
|
|
||||||
|
if not self.is_encoder_decoder:
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
del inputs["input_ids"]
|
||||||
|
else:
|
||||||
|
encoder_input_ids = inputs["input_ids"]
|
||||||
|
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||||
|
del inputs["input_ids"]
|
||||||
|
inputs.pop("decoder_input_ids", None)
|
||||||
|
|
||||||
|
if not self.is_encoder_decoder:
|
||||||
|
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
|
||||||
|
else:
|
||||||
|
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
|
||||||
|
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs, model_class)
|
||||||
|
|
||||||
|
model(inputs)
|
||||||
|
|
||||||
|
# TFBartForSequenceClassification does not support inputs_embeds
|
||||||
|
@slow
|
||||||
|
def test_graph_mode_with_inputs_embeds(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in (TFBartForConditionalGeneration, TFBartModel):
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
inputs = copy.deepcopy(inputs_dict)
|
||||||
|
|
||||||
|
if not self.is_encoder_decoder:
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
del inputs["input_ids"]
|
||||||
|
else:
|
||||||
|
encoder_input_ids = inputs["input_ids"]
|
||||||
|
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||||
|
del inputs["input_ids"]
|
||||||
|
inputs.pop("decoder_input_ids", None)
|
||||||
|
|
||||||
|
if not self.is_encoder_decoder:
|
||||||
|
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
|
||||||
|
else:
|
||||||
|
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
|
||||||
|
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs, model_class)
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def run_in_graph_mode():
|
||||||
|
return model(inputs)
|
||||||
|
|
||||||
|
outputs = run_in_graph_mode()
|
||||||
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_save_load_after_resize_token_embeddings(self):
|
||||||
|
# Custom version of this test to ensure "end of sequence" tokens are present throughout
|
||||||
|
if not self.test_resize_embeddings:
|
||||||
|
return
|
||||||
|
config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# create a model with resized (expanded) embeddings
|
||||||
|
new_tokens_size = 10
|
||||||
|
old_total_size = config.vocab_size
|
||||||
|
new_total_size = old_total_size + new_tokens_size
|
||||||
|
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||||
|
model(model.dummy_inputs) # builds the embeddings layer
|
||||||
|
model.resize_token_embeddings(new_total_size)
|
||||||
|
|
||||||
|
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||||
|
inputs_dict = copy.deepcopy(original_inputs_dict)
|
||||||
|
ids_feat_name = None
|
||||||
|
if "input_ids" in inputs_dict:
|
||||||
|
ids_feat_name = "input_ids"
|
||||||
|
elif "decoder_input_ids" in inputs_dict:
|
||||||
|
ids_feat_name = "decoder_input_ids"
|
||||||
|
else:
|
||||||
|
assert False, "No input ids feature found in the inputs dict"
|
||||||
|
|
||||||
|
new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size)
|
||||||
|
new_vocab_input_ids += old_total_size
|
||||||
|
|
||||||
|
# Replace last id with EOS token
|
||||||
|
new_vocab_input_ids = new_vocab_input_ids[:, :-1]
|
||||||
|
new_vocab_input_ids = tf.concat(
|
||||||
|
[new_vocab_input_ids, tf.ones((tf.shape(new_vocab_input_ids)[0], 1), dtype=tf.int32) * 2], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs_dict[ids_feat_name] = new_vocab_input_ids
|
||||||
|
if "input_ids" in inputs_dict:
|
||||||
|
inputs_dict["input_ids"] = new_vocab_input_ids
|
||||||
|
if "decoder_input_ids" in inputs_dict:
|
||||||
|
inputs_dict["decoder_input_ids"] = new_vocab_input_ids
|
||||||
|
prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
outputs = model(**prepared_inputs)
|
||||||
|
|
||||||
|
# save and load the model
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname, saved_model=False)
|
||||||
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
|
restored_model_outputs = model(**prepared_inputs)
|
||||||
|
|
||||||
|
# check that the output for the restored model is the same
|
||||||
|
self.assert_outputs_same(restored_model_outputs, outputs)
|
||||||
|
|
||||||
|
|
||||||
def _long_tensor(tok_lst):
|
def _long_tensor(tok_lst):
|
||||||
return tf.constant(tok_lst, dtype=tf.int32)
|
return tf.constant(tok_lst, dtype=tf.int32)
|
||||||
@@ -286,6 +409,19 @@ class TFBartHeadTests(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFBartForSequenceClassificationTest(unittest.TestCase):
|
||||||
|
def test_model_fails_for_uneven_eos_tokens(self):
|
||||||
|
config = BartConfig(eos_token_id=2)
|
||||||
|
model = TFBartForSequenceClassification(config)
|
||||||
|
inputs = {
|
||||||
|
"input_ids": tf.constant([[1, 2, 2, 2], [1, 3, 2, 2], [2, 2, 3, 3]]),
|
||||||
|
"attention_mask": tf.constant([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]),
|
||||||
|
}
|
||||||
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
|
model(inputs)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFBartModelIntegrationTest(unittest.TestCase):
|
class TFBartModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user