From f85acb4d73a84fe9bee5279068b0430fc391fb36 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan Date: Tue, 6 Sep 2022 17:42:26 +0530 Subject: [PATCH] Fix decode_input_ids to bare T5Model and improve doc (#18791) * use tokenizer to output tensor * add preprocessing for decoder_input_ids for bare T5Model * add preprocessing to tf and flax * linting * linting * Update src/transformers/models/t5/modeling_flax_t5.py Co-authored-by: Patrick von Platen * Update src/transformers/models/t5/modeling_tf_t5.py Co-authored-by: Patrick von Platen * Update src/transformers/models/t5/modeling_t5.py Co-authored-by: Patrick von Platen Co-authored-by: Patrick von Platen --- docs/source/en/model_doc/t5.mdx | 7 +++++-- src/transformers/models/t5/modeling_flax_t5.py | 4 ++++ src/transformers/models/t5/modeling_t5.py | 4 ++++ src/transformers/models/t5/modeling_tf_t5.py | 4 ++++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/t5.mdx b/docs/source/en/model_doc/t5.mdx index 5a19289234..92cd753b64 100644 --- a/docs/source/en/model_doc/t5.mdx +++ b/docs/source/en/model_doc/t5.mdx @@ -187,12 +187,15 @@ ignored. The code example below illustrates all of this. >>> # encode the targets >>> target_encoding = tokenizer( -... [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True +... [output_sequence_1, output_sequence_2], +... padding="longest", +... max_length=max_target_length, +... truncation=True, +... return_tensors="pt", ... ) >>> labels = target_encoding.input_ids >>> # replace padding token id's of the labels by -100 so it's ignored by the loss ->>> labels = torch.tensor(labels) >>> labels[labels == tokenizer.pad_token_id] = -100 >>> # forward pass diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index 918a605fc4..2732bf5916 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -1388,6 +1388,10 @@ FLAX_T5_MODEL_DOCSTRING = """ ... ).input_ids >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + >>> # forward pass >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 1e70ba773a..8e414cbd7a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1383,6 +1383,10 @@ class T5Model(T5PreTrainedModel): ... ).input_ids # Batch size 1 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + >>> # forward pass >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 091cb9d63e..dc909c8d8f 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1180,6 +1180,10 @@ class TFT5Model(TFT5PreTrainedModel): ... ).input_ids # Batch size 1 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1 + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + >>> # forward pass >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state