Add TFBartForSequenceClassification (#20570)
* read to load * base functionality * revert init * fix dummy data * moving right along * moving right along * finally * cleanup * pull out comment * add test * update docstring for main class * flake comments and rewriting copies from make repo-consistency` * remove irrelevant differences/accidental spaces * put copies back after space removals * mid * final test pass * stray comment * update test file * update test file * fixup * black * missed * black missed one more * sytle * add doc update * fix order of output class * comment * Revert "comment" This reverts commit 03f86b6948808461939cc8ad4ad74305dfb67700. * remove redundant function, and redundant reshape * move change out of common * style * put common spaces back * reorder kwargs in output * doc style
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -29,7 +31,7 @@ from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFBartForConditionalGeneration, TFBartModel
|
||||
from transformers import TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel
|
||||
|
||||
|
||||
@require_tf
|
||||
@@ -76,7 +78,13 @@ class TFBartModelTester:
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
|
||||
# Ids are clipped to avoid "beginng of sequence", "end of sequence", and "pad" tokens
|
||||
input_ids = tf.clip_by_value(
|
||||
ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size),
|
||||
clip_value_min=self.eos_token_id + 1,
|
||||
clip_value_max=self.vocab_size + 1,
|
||||
)
|
||||
# Explicity add "end of sequence" to the inputs
|
||||
eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
|
||||
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
|
||||
|
||||
@@ -181,7 +189,9 @@ def prepare_bart_inputs_dict(
|
||||
|
||||
@require_tf
|
||||
class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
|
||||
all_model_classes = (
|
||||
(TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel) if is_tf_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
@@ -228,6 +238,119 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
||||
def test_onnx_compliancy(self):
|
||||
pass
|
||||
|
||||
# TFBartForSequenceClassification does not support inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in (TFBartForConditionalGeneration, TFBartModel):
|
||||
model = model_class(config)
|
||||
|
||||
inputs = copy.deepcopy(inputs_dict)
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
else:
|
||||
encoder_input_ids = inputs["input_ids"]
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||
del inputs["input_ids"]
|
||||
inputs.pop("decoder_input_ids", None)
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
|
||||
else:
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
|
||||
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
|
||||
|
||||
inputs = self._prepare_for_class(inputs, model_class)
|
||||
|
||||
model(inputs)
|
||||
|
||||
# TFBartForSequenceClassification does not support inputs_embeds
|
||||
@slow
|
||||
def test_graph_mode_with_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in (TFBartForConditionalGeneration, TFBartModel):
|
||||
model = model_class(config)
|
||||
|
||||
inputs = copy.deepcopy(inputs_dict)
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
else:
|
||||
encoder_input_ids = inputs["input_ids"]
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||
del inputs["input_ids"]
|
||||
inputs.pop("decoder_input_ids", None)
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
|
||||
else:
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
|
||||
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
|
||||
|
||||
inputs = self._prepare_for_class(inputs, model_class)
|
||||
|
||||
@tf.function
|
||||
def run_in_graph_mode():
|
||||
return model(inputs)
|
||||
|
||||
outputs = run_in_graph_mode()
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@slow
|
||||
def test_save_load_after_resize_token_embeddings(self):
|
||||
# Custom version of this test to ensure "end of sequence" tokens are present throughout
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
# create a model with resized (expended) embeddings
|
||||
new_tokens_size = 10
|
||||
old_total_size = config.vocab_size
|
||||
new_total_size = old_total_size + new_tokens_size
|
||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||
model(model.dummy_inputs) # builds the embeddings layer
|
||||
model.resize_token_embeddings(new_total_size)
|
||||
|
||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||
inputs_dict = copy.deepcopy(original_inputs_dict)
|
||||
ids_feat_name = None
|
||||
if "input_ids" in inputs_dict:
|
||||
ids_feat_name = "input_ids"
|
||||
elif "decoder_input_ids" in inputs_dict:
|
||||
ids_feat_name = "decoder_input_ids"
|
||||
else:
|
||||
assert False, "No input ids feature found in the inputs dict"
|
||||
|
||||
new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size)
|
||||
new_vocab_input_ids += old_total_size
|
||||
|
||||
# Replace last id with EOS token
|
||||
new_vocab_input_ids = new_vocab_input_ids[:, :-1]
|
||||
new_vocab_input_ids = tf.concat(
|
||||
[new_vocab_input_ids, tf.ones((tf.shape(new_vocab_input_ids)[0], 1), dtype=tf.int32) * 2], axis=1
|
||||
)
|
||||
|
||||
inputs_dict[ids_feat_name] = new_vocab_input_ids
|
||||
if "input_ids" in inputs_dict:
|
||||
inputs_dict["input_ids"] = new_vocab_input_ids
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"] = new_vocab_input_ids
|
||||
prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(**prepared_inputs)
|
||||
|
||||
# save and load the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
restored_model_outputs = model(**prepared_inputs)
|
||||
|
||||
# check that the output for the restored model is the same
|
||||
self.assert_outputs_same(restored_model_outputs, outputs)
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
@@ -286,6 +409,19 @@ class TFBartHeadTests(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFBartForSequenceClassificationTest(unittest.TestCase):
|
||||
def test_model_fails_for_uneven_eos_tokens(self):
|
||||
config = BartConfig(eos_token_id=2)
|
||||
model = TFBartForSequenceClassification(config)
|
||||
inputs = {
|
||||
"input_ids": tf.constant([[1, 2, 2, 2], [1, 3, 2, 2], [2, 2, 3, 3]]),
|
||||
"attention_mask": tf.constant([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]),
|
||||
}
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
model(inputs)
|
||||
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
class TFBartModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user