[T5] make decoder input ids optional for t5 training (#3521)
* make decoder input ids optional for t5 training * lm_lables should not be shifted in t5 * add tests * finish shift right functionality for PT T5 * move shift right to correct class * cleaner code * replace -100 values with pad token id * add assert statement * remove unnecessary for loop * make style
This commit is contained in:
committed by
GitHub
parent
5b44e0a31b
commit
75ec6c9e3a
@@ -502,6 +502,27 @@ class T5PreTrainedModel(PreTrainedModel):
|
||||
if module.has_relative_attention_bias:
|
||||
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
|
||||
|
||||
def _shift_right(self, input_ids):
|
||||
decoder_start_token_id = self.config.decoder_start_token_id
|
||||
pad_token_id = self.config.pad_token_id
|
||||
|
||||
assert (
|
||||
decoder_start_token_id is not None
|
||||
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||
|
||||
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
|
||||
# replace possible -100 values in lm_labels by `pad_token_id`
|
||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
||||
|
||||
assert torch.all(shifted_input_ids >= 0).item(), "Verify that `lm_labels` has only positive values and -100"
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
class T5Stack(T5PreTrainedModel):
|
||||
def __init__(self, config, embed_tokens=None):
|
||||
@@ -923,6 +944,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if lm_labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
# get decoder inputs from shifting lm labels to the right
|
||||
decoder_input_ids = self._shift_right(lm_labels)
|
||||
|
||||
# Decode
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
@@ -941,10 +966,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here
|
||||
if lm_labels is not None:
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = lm_labels[..., 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
|
||||
decoder_outputs = (
|
||||
loss,
|
||||
) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
||||
|
||||
@@ -523,8 +523,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
pad_token_id: (`optional`) int
|
||||
Pad token. Defaults to pad_token_id as defined in the models config.
|
||||
|
||||
eos_token_ids: (`optional`) int or list of int
|
||||
End of sequence token or list of tokens to stop the generation. Default to 0.
|
||||
eos_token_id: (`optional`) int
|
||||
EOS token. Defaults to eos_token_id as defined in the models config.
|
||||
|
||||
length_penalty: (`optional`) float
|
||||
Exponential penalty to the length. Default to 1.
|
||||
|
||||
@@ -721,13 +721,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
Padding token. Default to specicic model pad_token_id or None if it does not exist.
|
||||
|
||||
bos_token_id: (`optional`) int
|
||||
BOS token. Defaults to bos_token_id as defined in the models config.
|
||||
BOS token. Defaults to `bos_token_id` as defined in the models config.
|
||||
|
||||
pad_token_id: (`optional`) int
|
||||
Pad token. Defaults to pad_token_id as defined in the models config.
|
||||
|
||||
eos_token_ids: (`optional`) int or list of int
|
||||
End of sequence token or list of tokens to stop the generation. Default to eos_token_ids as defined in the models config.
|
||||
eos_token_id: (`optional`) int
|
||||
EOS token. Defaults to `eos_token_id` as defined in the models config.
|
||||
|
||||
length_penalty: (`optional`) float
|
||||
Exponential penalty to the length. Default to 1.
|
||||
|
||||
Reference in New Issue
Block a user