New TF model inputs (#8602)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add input processing for TF Flaubert * Add deprecated arguments * Add input processing to TF XLM * remove unused imports * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Bug fix * Retry to bugfix * Retry bug fix * Fix wrong model name * Try another fix * Fix BART * Fix input precessing * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Bug fix * Address Patrick's comments * Address Patrick's comments * Address Sylvain's comments * Add the new inputs in new Longformer models * Update the template with the new input processing * Remove useless assert * Apply style * Trigger CI
This commit is contained in:
@@ -152,7 +152,7 @@ class TFModelTesterMixin:
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
expected_arg_names = [
|
||||
"inputs",
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
@@ -161,7 +161,7 @@ class TFModelTesterMixin:
|
||||
self.assertListEqual(arg_names[:5], expected_arg_names)
|
||||
|
||||
else:
|
||||
expected_arg_names = ["inputs"]
|
||||
expected_arg_names = ["input_ids"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
@slow
|
||||
@@ -753,7 +753,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user