Add TFBartForConditionalGeneration (#5411)
* half done * doc improvement * Cp test file * brokedn * broken test * undo some mess * ckpt * borked * Halfway * 6 passing * boom boom * Much progress but still 6 * boom boom * merged master * 10 passing * boom boom * Style * no t5 changes * 13 passing * Integration test failing, but not gibberish * Frustrated * Merged master * 4 fail * 4 fail * fix return_dict * boom boom * Still only 4 * prepare method * prepare method * before delete classif * Skip tests to avoid adding boilerplate * boom boom * fast tests passing * style * boom boom * Switch to supporting many input types * remove FIXMENORM * working * Fixed past_key_values/decoder_cached_states confusion * new broken test * Fix attention mask kwarg name * undo accidental * Style and reviewers * style * Docs and common tests * Cleaner assert messages * copy docs * style issues * Sphinx fix * Simplify caching logic * test does not require torch * copy _NoLayerEmbedTokens * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update tests/test_modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Line length and dont document None * Add pipeline test coverage * assert msg * At parity * Assert messages * mark slow * Update compile test * back in init * Merge master * Fix tests Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
File diff suppressed because one or more lines are too long
357
tests/test_modeling_tf_bart.py
Normal file
357
tests/test_modeling_tf_bart.py
Normal file
File diff suppressed because one or more lines are too long
@@ -302,7 +302,7 @@ class TFModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
config.output_hidden_states = True
|
||||
@@ -472,10 +472,9 @@ class TFModelTesterMixin:
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class)) # build the model
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
@@ -494,7 +493,9 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs_dict = model(inputs)
|
||||
|
||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
input_ids = inputs_keywords.pop("input_ids", None)
|
||||
@@ -507,28 +508,18 @@ class TFModelTesterMixin:
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
decoder_seq_length = (
|
||||
self.model_tester.decoder_seq_length
|
||||
if hasattr(self.model_tester, "decoder_seq_length")
|
||||
else self.model_tester.seq_length
|
||||
)
|
||||
encoder_seq_length = (
|
||||
self.model_tester.encoder_seq_length
|
||||
if hasattr(self.model_tester, "encoder_seq_length")
|
||||
else self.model_tester.seq_length
|
||||
)
|
||||
decoder_key_length = (
|
||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
|
||||
)
|
||||
encoder_key_length = (
|
||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
|
||||
)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
|
||||
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["use_cache"] = False
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config)
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
model_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(model_inputs)
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
@@ -279,9 +279,8 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ["t5-small"]:
|
||||
model = TFT5Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
model = TFT5Model.from_pretrained("t5-small")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_tf
|
||||
|
||||
@@ -21,7 +21,7 @@ FILL_MASK_FINETUNED_MODELS = ["sshleifer/tiny-distilroberta-base"]
|
||||
LARGE_FILL_MASK_FINETUNED_MODELS = ["distilroberta-base"] # @slow
|
||||
|
||||
SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
|
||||
TF_SUMMARIZATION_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
|
||||
TF_SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
|
||||
|
||||
TRANSLATION_FINETUNED_MODELS = [
|
||||
("patrickvonplaten/t5-tiny-random", "translation_en_to_de"),
|
||||
|
||||
Reference in New Issue
Block a user