Support T5 Generation (#3228)

* fix conflicts

* update bart max length test

* correct spelling mistakes

* implemented model specific encode function

* fix merge conflicts

* better naming

* save intermediate state -> need to rethink strucuture a bit

* leave tf problem as it is for now

* current version

* add layers.pop

* remove ipdb

* make style

* clean return cut decoding

* remove ipdbs

* Fix restoring layers in the decoders that doesnt exists.

* push good intermediate solution for now

* fix conflicts

* always good to refuse to merge conflicts when rebasing

* fix small bug

* improve function calls

* remove unused file

* add correct scope behavior for t5_generate

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Patrick von Platen
2020-03-19 23:18:23 +01:00
committed by GitHub
parent 656e1386a2
commit bbf26c4e61
16 changed files with 449 additions and 280 deletions

View File

@@ -147,7 +147,7 @@ class ModelTesterMixin:
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
)
decoder_attention_idx = 1
if "lm_labels" in inputs_dict or "decoder_lm_labels" in inputs_dict: # loss will come first
if "lm_labels" in inputs_dict: # loss will come first
correct_outlen += 1 # compute loss
decoder_attention_idx += 1
self.assertEqual(out_len, correct_outlen)
@@ -601,9 +601,9 @@ class ModelTesterMixin:
input_ids = inputs_dict["input_ids"]
del inputs_dict["input_ids"]
else:
encoder_input_ids = inputs_dict["encoder_input_ids"]
encoder_input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
del inputs_dict["encoder_input_ids"]
del inputs_dict["input_ids"]
inputs_dict.pop("decoder_input_ids", None)
for model_class in self.all_model_classes:
@@ -615,7 +615,7 @@ class ModelTesterMixin:
if not self.is_encoder_decoder:
inputs_dict["inputs_embeds"] = wte(input_ids)
else:
inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
@@ -624,9 +624,7 @@ class ModelTesterMixin:
def test_lm_head_model_random_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get(
"input_ids", None
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
input_ids = inputs_dict.get("input_ids")
if self.is_encoder_decoder:
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models