New TF model inputs (#8602)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add input processing for TF Flaubert * Add deprecated arguments * Add input processing to TF XLM * remove unused imports * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Bug fix * Retry to bugfix * Retry bug fix * Fix wrong model name * Try another fix * Fix BART * Fix input precessing * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Bug fix * Address Patrick's comments * Address Patrick's comments * Address Sylvain's comments * Add the new inputs in new Longformer models * Update the template with the new input processing * Remove useless assert * Apply style * Trigger CI
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict(
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
||||
class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
model_tester_cls = TFBartModelTester
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.model_tester = TFBartModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BartConfig)
|
||||
|
||||
def test_config(self):
|
||||
@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
lm_model = TFBartForConditionalGeneration(config)
|
||||
outputs = lm_model(inputs=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
|
||||
outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase):
|
||||
lm_model = TFBartForConditionalGeneration(config)
|
||||
context = tf.fill((7, 2), 4)
|
||||
summary = tf.fill((7, 7), 6)
|
||||
outputs = lm_model(inputs=context, decoder_input_ids=summary, use_cache=False)
|
||||
outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user