Support T5 Generation (#3228)

* fix conflicts

* update bart max length test

* correct spelling mistakes

* implemented model specific encode function

* fix merge conflicts

* better naming

* save intermediate state -> need to rethink strucuture a bit

* leave tf problem as it is for now

* current version

* add layers.pop

* remove ipdb

* make style

* clean return cut decoding

* remove ipdbs

* Fix restoring layers in the decoders that doesnt exists.

* push good intermediate solution for now

* fix conflicts

* always good to refuse to merge conflicts when rebasing

* fix small bug

* improve function calls

* remove unused file

* add correct scope behavior for t5_generate

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Patrick von Platen
2020-03-19 23:18:23 +01:00
committed by GitHub
parent 656e1386a2
commit bbf26c4e61
16 changed files with 449 additions and 280 deletions

View File

@@ -148,10 +148,12 @@ class TFModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
config.output_hidden_states = True
tf_model = model_class(config)
pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
@@ -221,7 +223,7 @@ class TFModelTesterMixin:
if self.is_encoder_decoder:
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"encoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
else:
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
@@ -393,9 +395,9 @@ class TFModelTesterMixin:
input_ids = inputs_dict["input_ids"]
del inputs_dict["input_ids"]
else:
encoder_input_ids = inputs_dict["encoder_input_ids"]
encoder_input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict["decoder_input_ids"]
del inputs_dict["encoder_input_ids"]
del inputs_dict["input_ids"]
del inputs_dict["decoder_input_ids"]
for model_class in self.all_model_classes:
@@ -405,7 +407,7 @@ class TFModelTesterMixin:
if not self.is_encoder_decoder:
inputs_dict["inputs_embeds"] = self._get_embeds(wte, input_ids)
else:
inputs_dict["encoder_inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
inputs_dict["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
model(inputs_dict)
@@ -413,9 +415,10 @@ class TFModelTesterMixin:
def test_lm_head_model_random_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get(
"input_ids", None
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
input_ids = inputs_dict["input_ids"]
if self.is_encoder_decoder:
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
for model_class in self.all_generative_model_classes:
model = model_class(config)