diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index a2362f9866..d768ddaee7 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -502,6 +502,27 @@ class T5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
+ # replace possible -100 values in lm_labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `lm_labels` has only positive values and -100" + + return shifted_input_ids + class T5Stack(T5PreTrainedModel): def __init__(self, config, embed_tokens=None): @@ -923,6 +944,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel): hidden_states = encoder_outputs[0] + if lm_labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(lm_labels) + # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -941,10 +966,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel): decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here if lm_labels is not None: - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = lm_labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)) decoder_outputs = ( loss, ) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 9e4a155e2d..97400cf233 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -523,8 +523,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): pad_token_id: (`optional`) int Pad token. Defaults to pad_token_id as defined in the models config. - eos_token_ids: (`optional`) int or list of int - End of sequence token or list of tokens to stop the generation. Default to 0. 
+ eos_token_id: (`optional`) int + EOS token. Defaults to eos_token_id as defined in the models config. length_penalty: (`optional`) float Exponential penalty to the length. Default to 1. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9d4abb2ded..6063fccc41 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -721,13 +721,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): Padding token. Default to specicic model pad_token_id or None if it does not exist. bos_token_id: (`optional`) int - BOS token. Defaults to bos_token_id as defined in the models config. + BOS token. Defaults to `bos_token_id` as defined in the models config. - pad_token_id: (`optional`) int - Pad token. Defaults to pad_token_id as defined in the models config. - - eos_token_ids: (`optional`) int or list of int - End of sequence token or list of tokens to stop the generation. Default to eos_token_ids as defined in the models config. + eos_token_id: (`optional`) int + EOS token. Defaults to `eos_token_id` as defined in the models config. length_penalty: (`optional`) float Exponential penalty to the length. Default to 1. 
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index db7ce63317..af59de5aa3 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -468,7 +468,7 @@ class BartModelIntegrationTests(unittest.TestCase): length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True, - decoder_start_token_id=model.config.eos_token_ids[0], + decoder_start_token_id=model.config.eos_token_id, ) decoded = [ diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index c8f9de3cc9..1f3344d1ba 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device if is_torch_available(): + import torch from transformers import T5Config, T5Model, T5ForConditionalGeneration from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP @@ -57,8 +58,9 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): relative_attention_num_buckets=8, dropout_rate=0.1, initializer_factor=0.002, - eos_token_ids=[1], + eos_token_id=1, pad_token_id=0, + decoder_start_token_id=0, scope=None, ): self.parent = parent @@ -78,8 +80,9 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): self.dropout_rate = dropout_rate self.initializer_factor = initializer_factor self.scope = scope - self.eos_token_ids = eos_token_ids + self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) @@ -106,9 +109,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): relative_attention_num_buckets=self.relative_attention_num_buckets, dropout_rate=self.dropout_rate, initializer_factor=self.initializer_factor, - eos_token_ids=self.eos_token_ids, + eos_token_id=self.eos_token_id, bos_token_id=self.pad_token_id, pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, ) 
return ( @@ -123,6 +127,39 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): def check_loss_output(self, result): self.parent.assertListEqual(list(result["loss"].size()), []) + def check_prepare_lm_labels_via_shift_left( + self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, + ): + model = T5Model(config=config) + model.to(torch_device) + model.eval() + + # make sure that lm_labels are correctly padded from the right + lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) + + # add causal pad token mask + triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() + lm_labels.masked_fill_(triangular_mask, self.pad_token_id) + decoder_input_ids = model._shift_right(lm_labels) + + for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)): + # first item + self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id) + if i < decoder_input_ids_slice.shape[-1]: + if i < decoder_input_ids.shape[-1] - 1: + # items before diagonal + self.parent.assertListEqual( + decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist() + ) + # pad items after diagonal + if i < decoder_input_ids.shape[-1] - 2: + self.parent.assertListEqual( + decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist() + ) + else: + # all items after square + self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) + def create_and_check_t5_model( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): @@ -197,6 +234,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() + def test_shift_right(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) + def 
test_t5_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_t5_model(*config_and_inputs) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 731de2540d..db3f0ac08f 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -52,7 +52,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): relative_attention_num_buckets=8, dropout_rate=0.1, initializer_factor=0.002, - eos_token_ids=[1], + eos_token_id=1, pad_token_id=0, scope=None, ): @@ -71,7 +71,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): self.relative_attention_num_buckets = relative_attention_num_buckets self.dropout_rate = dropout_rate self.initializer_factor = initializer_factor - self.eos_token_ids = eos_token_ids + self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.scope = scope @@ -97,7 +97,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): relative_attention_num_buckets=self.relative_attention_num_buckets, dropout_rate=self.dropout_rate, initializer_factor=self.initializer_factor, - eos_token_ids=self.eos_token_ids, + eos_token_id=self.eos_token_id, bos_token_id=self.pad_token_id, pad_token_id=self.pad_token_id, )