New TF model inputs (#8602)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add input processing for TF Flaubert * Add deprecated arguments * Add input processing to TF XLM * remove unused imports * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Bug fix * Retry to bugfix * Retry bug fix * Fix wrong model name * Try another fix * Fix BART * Fix input precessing * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Bug fix * Address Patrick's comments * Address Patrick's comments * Address Sylvain's comments * Add the new inputs in new Longformer models * Update the template with the new input processing * Remove useless assert * Apply style * Trigger CI
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict(
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
||||
class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
model_tester_cls = TFBartModelTester
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.model_tester = TFBartModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BartConfig)
|
||||
|
||||
def test_config(self):
|
||||
@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
lm_model = TFBartForConditionalGeneration(config)
|
||||
outputs = lm_model(inputs=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
|
||||
outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase):
|
||||
lm_model = TFBartForConditionalGeneration(config)
|
||||
context = tf.fill((7, 2), 4)
|
||||
summary = tf.fill((7, 7), 6)
|
||||
outputs = lm_model(inputs=context, decoder_input_ids=summary, use_cache=False)
|
||||
outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
|
||||
@@ -12,24 +12,23 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from tests.test_configuration_common import ConfigTester
|
||||
from tests.test_modeling_tf_bart import TFBartModelTester
|
||||
from tests.test_modeling_tf_common import TFModelTesterMixin
|
||||
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
|
||||
from transformers import (
|
||||
BlenderbotConfig,
|
||||
BlenderbotSmallTokenizer,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFBlenderbotForConditionalGeneration,
|
||||
is_tf_available,
|
||||
)
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
|
||||
|
||||
|
||||
class ModelTester(TFBartModelTester):
|
||||
class TFBlenderbotModelTester(TFBartModelTester):
|
||||
config_updates = dict(
|
||||
normalize_before=True,
|
||||
static_position_embeddings=True,
|
||||
@@ -40,15 +39,14 @@ class ModelTester(TFBartModelTester):
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
model_tester_cls = ModelTester
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.model_tester = TFBlenderbotModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
|
||||
|
||||
def test_config(self):
|
||||
@@ -66,37 +64,6 @@ class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_tokenizers
|
||||
|
||||
@@ -152,7 +152,7 @@ class TFModelTesterMixin:
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
expected_arg_names = [
|
||||
"inputs",
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
@@ -161,7 +161,7 @@ class TFModelTesterMixin:
|
||||
self.assertListEqual(arg_names[:5], expected_arg_names)
|
||||
|
||||
else:
|
||||
expected_arg_names = ["inputs"]
|
||||
expected_arg_names = ["input_ids"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
@slow
|
||||
@@ -753,7 +753,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user