Support T5 Generation (#3228)
* fix conflicts * update bart max length test * correct spelling mistakes * implemented model specific encode function * fix merge conflicts * better naming * save intermediate state -> need to rethink strucuture a bit * leave tf problem as it is for now * current version * add layers.pop * remove ipdb * make style * clean return cut decoding * remove ipdbs * Fix restoring layers in the decoders that doesnt exists. * push good intermediate solution for now * fix conflicts * always good to refuse to merge conflicts when rebasing * fix small bug * improve function calls * remove unused file * add correct scope behavior for t5_generate Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
committed by
GitHub
parent
656e1386a2
commit
bbf26c4e61
@@ -148,10 +148,12 @@ class TFModelTesterMixin:
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
config.output_hidden_states = True
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
@@ -221,7 +223,7 @@ class TFModelTesterMixin:
|
||||
if self.is_encoder_decoder:
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"encoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
else:
|
||||
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
|
||||
@@ -393,9 +395,9 @@ class TFModelTesterMixin:
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
else:
|
||||
encoder_input_ids = inputs_dict["encoder_input_ids"]
|
||||
encoder_input_ids = inputs_dict["input_ids"]
|
||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
del inputs_dict["encoder_input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
del inputs_dict["decoder_input_ids"]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -405,7 +407,7 @@ class TFModelTesterMixin:
|
||||
if not self.is_encoder_decoder:
|
||||
inputs_dict["inputs_embeds"] = self._get_embeds(wte, input_ids)
|
||||
else:
|
||||
inputs_dict["encoder_inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
|
||||
inputs_dict["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
|
||||
inputs_dict["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
|
||||
|
||||
model(inputs_dict)
|
||||
@@ -413,9 +415,10 @@ class TFModelTesterMixin:
|
||||
def test_lm_head_model_random_generate(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get(
|
||||
"input_ids", None
|
||||
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
Reference in New Issue
Block a user