Fix decode_input_ids to bare T5Model and improve doc (#18791)
* use tokenizer to output tensor * add preprocessing for decoder_input_ids for bare T5Model * add preprocessing to tf and flax * linting * linting * Update src/transformers/models/t5/modeling_flax_t5.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/models/t5/modeling_tf_t5.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/models/t5/modeling_t5.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -187,12 +187,15 @@ ignored. The code example below illustrates all of this.
|
|||||||
|
|
||||||
>>> # encode the targets
|
>>> # encode the targets
|
||||||
>>> target_encoding = tokenizer(
|
>>> target_encoding = tokenizer(
|
||||||
... [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
|
... [output_sequence_1, output_sequence_2],
|
||||||
|
... padding="longest",
|
||||||
|
... max_length=max_target_length,
|
||||||
|
... truncation=True,
|
||||||
|
... return_tensors="pt",
|
||||||
... )
|
... )
|
||||||
>>> labels = target_encoding.input_ids
|
>>> labels = target_encoding.input_ids
|
||||||
|
|
||||||
>>> # replace padding token id's of the labels by -100 so it's ignored by the loss
|
>>> # replace padding token id's of the labels by -100 so it's ignored by the loss
|
||||||
>>> labels = torch.tensor(labels)
|
|
||||||
>>> labels[labels == tokenizer.pad_token_id] = -100
|
>>> labels[labels == tokenizer.pad_token_id] = -100
|
||||||
|
|
||||||
>>> # forward pass
|
>>> # forward pass
|
||||||
|
|||||||
@@ -1388,6 +1388,10 @@ FLAX_T5_MODEL_DOCSTRING = """
|
|||||||
... ).input_ids
|
... ).input_ids
|
||||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
|
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
|
||||||
|
|
||||||
|
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
|
||||||
|
>>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
|
||||||
|
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
|
||||||
|
|
||||||
>>> # forward pass
|
>>> # forward pass
|
||||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
|||||||
@@ -1383,6 +1383,10 @@ class T5Model(T5PreTrainedModel):
|
|||||||
... ).input_ids # Batch size 1
|
... ).input_ids # Batch size 1
|
||||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
|
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
|
||||||
|
|
||||||
|
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
|
||||||
|
>>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
|
||||||
|
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
|
||||||
|
|
||||||
>>> # forward pass
|
>>> # forward pass
|
||||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
|||||||
@@ -1180,6 +1180,10 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
... ).input_ids # Batch size 1
|
... ).input_ids # Batch size 1
|
||||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
|
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
|
||||||
|
|
||||||
|
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
|
||||||
|
>>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
|
||||||
|
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
|
||||||
|
|
||||||
>>> # forward pass
|
>>> # forward pass
|
||||||
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
|
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
|||||||
Reference in New Issue
Block a user