Add TFBartForSequenceClassification (#20570)

* read to load

* base functionality

* revert init

* fix dummy data

* moving right along

* moving right along

* finally

* cleanup

* pull out comment

* add test

* update docstring for main class

* flake comments and rewriting copies from make repo-consistency`

* remove irrelevant differences/accidental spaces

* put copies back after space removals

* mid

* final test pass

* stray comment

* update test file

* update test file

* fixup

* black

* missed

* black missed one more

* sytle

* add doc update

* fix order of output class

* comment

* Revert "comment"

This reverts commit 03f86b6948808461939cc8ad4ad74305dfb67700.

* remove redundant function, and redundant reshape

* move change out of common

* style

* put common spaces back

* reorder kwargs in output

* doc style
This commit is contained in:
Cole Howard
2022-12-07 09:05:39 -08:00
committed by GitHub
parent 77382e918d
commit fc95386ea1
10 changed files with 338 additions and 9 deletions

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
import numpy as np
@@ -29,7 +31,7 @@ from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFBartForConditionalGeneration, TFBartModel
from transformers import TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel
@require_tf
@@ -76,7 +78,13 @@ class TFBartModelTester:
self.bos_token_id = bos_token_id
def prepare_config_and_inputs_for_common(self):
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
# Ids are clipped to avoid "beginng of sequence", "end of sequence", and "pad" tokens
input_ids = tf.clip_by_value(
ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size),
clip_value_min=self.eos_token_id + 1,
clip_value_max=self.vocab_size + 1,
)
# Explicity add "end of sequence" to the inputs
eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
@@ -181,7 +189,9 @@ def prepare_bart_inputs_dict(
@require_tf
class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
all_model_classes = (
(TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel) if is_tf_available() else ()
)
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
@@ -228,6 +238,119 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
def test_onnx_compliancy(self):
pass
# TFBartForSequenceClassification does not support inputs_embeds
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in (TFBartForConditionalGeneration, TFBartModel):
model = model_class(config)
inputs = copy.deepcopy(inputs_dict)
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
del inputs["input_ids"]
else:
encoder_input_ids = inputs["input_ids"]
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
del inputs["input_ids"]
inputs.pop("decoder_input_ids", None)
if not self.is_encoder_decoder:
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
else:
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
inputs = self._prepare_for_class(inputs, model_class)
model(inputs)
# TFBartForSequenceClassification does not support inputs_embeds
@slow
def test_graph_mode_with_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in (TFBartForConditionalGeneration, TFBartModel):
model = model_class(config)
inputs = copy.deepcopy(inputs_dict)
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
del inputs["input_ids"]
else:
encoder_input_ids = inputs["input_ids"]
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
del inputs["input_ids"]
inputs.pop("decoder_input_ids", None)
if not self.is_encoder_decoder:
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
else:
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
inputs = self._prepare_for_class(inputs, model_class)
@tf.function
def run_in_graph_mode():
return model(inputs)
outputs = run_in_graph_mode()
self.assertIsNotNone(outputs)
@slow
def test_save_load_after_resize_token_embeddings(self):
# Custom version of this test to ensure "end of sequence" tokens are present throughout
if not self.test_resize_embeddings:
return
config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# create a model with resized (expended) embeddings
new_tokens_size = 10
old_total_size = config.vocab_size
new_total_size = old_total_size + new_tokens_size
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
model(model.dummy_inputs) # builds the embeddings layer
model.resize_token_embeddings(new_total_size)
# fetch the output for an input exclusively made of new members of the vocabulary
inputs_dict = copy.deepcopy(original_inputs_dict)
ids_feat_name = None
if "input_ids" in inputs_dict:
ids_feat_name = "input_ids"
elif "decoder_input_ids" in inputs_dict:
ids_feat_name = "decoder_input_ids"
else:
assert False, "No input ids feature found in the inputs dict"
new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size)
new_vocab_input_ids += old_total_size
# Replace last id with EOS token
new_vocab_input_ids = new_vocab_input_ids[:, :-1]
new_vocab_input_ids = tf.concat(
[new_vocab_input_ids, tf.ones((tf.shape(new_vocab_input_ids)[0], 1), dtype=tf.int32) * 2], axis=1
)
inputs_dict[ids_feat_name] = new_vocab_input_ids
if "input_ids" in inputs_dict:
inputs_dict["input_ids"] = new_vocab_input_ids
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"] = new_vocab_input_ids
prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs)
# save and load the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=False)
model = model_class.from_pretrained(tmpdirname)
restored_model_outputs = model(**prepared_inputs)
# check that the output for the restored model is the same
self.assert_outputs_same(restored_model_outputs, outputs)
def _long_tensor(tok_lst):
return tf.constant(tok_lst, dtype=tf.int32)
@@ -286,6 +409,19 @@ class TFBartHeadTests(unittest.TestCase):
self.assertEqual(outputs.logits.shape, expected_shape)
@require_tf
class TFBartForSequenceClassificationTest(unittest.TestCase):
def test_model_fails_for_uneven_eos_tokens(self):
config = BartConfig(eos_token_id=2)
model = TFBartForSequenceClassification(config)
inputs = {
"input_ids": tf.constant([[1, 2, 2, 2], [1, 3, 2, 2], [2, 2, 3, 3]]),
"attention_mask": tf.constant([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]),
}
with self.assertRaises(tf.errors.InvalidArgumentError):
model(inputs)
@slow
@require_tf
class TFBartModelIntegrationTest(unittest.TestCase):