From a332cc9f7f5009ae53d2dd66507d8a7710dc7ba7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Mar 2020 11:53:36 +0100 Subject: [PATCH] finalize generation merge --- src/transformers/configuration_bart.py | 6 +++--- src/transformers/modeling_bart.py | 6 +++--- src/transformers/modeling_utils.py | 3 --- tests/test_modeling_bart.py | 8 ++++---- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 03791c42bb..f6733a9bc4 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -40,8 +40,9 @@ class BartConfig(PretrainedConfig): self, activation_dropout=0.0, vocab_size=50265, + bos_token_id=0, pad_token_id=1, - eos_token_id=2, + eos_token_ids=[2], d_model=1024, encoder_ffn_dim=4096, encoder_layers=12, @@ -58,7 +59,6 @@ class BartConfig(PretrainedConfig): classifier_dropout=0.0, output_past=False, num_labels=3, - bos_token_id=0, is_encoder_decoder=True, **common_kwargs ): @@ -73,12 +73,12 @@ class BartConfig(PretrainedConfig): output_past=output_past, pad_token_id=pad_token_id, bos_token_id=bos_token_id, + eos_token_ids=eos_token_ids, is_encoder_decoder=is_encoder_decoder, **common_kwargs, ) self.vocab_size = vocab_size self.d_model = d_model # encoder_embed_dim and decoder_embed_dim - self.eos_token_id = eos_token_id self.encoder_ffn_dim = encoder_ffn_dim self.encoder_layers = self.num_hidden_layers = encoder_layers self.encoder_attention_heads = encoder_attention_heads diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index bb32b1c96c..2bea8e6fc8 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -962,8 +962,8 @@ class BartForConditionalGeneration(PretrainedBartModel): def prepare_scores_for_generation(self, scores, cur_len, max_length): if cur_len == 1: self._force_token_ids_generation(scores, self.config.bos_token_id) - if cur_len == max_length - 1: - self._force_token_ids_generation(scores, self.config.eos_token_ids) + if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None: + self._force_token_ids_generation(scores, self.config.eos_token_ids[0]) return scores @staticmethod @@ -1056,7 +1056,7 @@ class BartForSequenceClassification(PretrainedBartModel): encoder_outputs=encoder_outputs, ) x = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_ids[0]) if len(torch.unique(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 62c6f9950d..6e26a9318d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -840,14 +840,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: - eos_token_id = eos_token_ids[0] assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id" - assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id" # encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs encoder_inputs = input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), - # eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this bos_token_id, dtype=torch.long, device=next(self.parameters()).device, diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index b23f01066e..ada32ab647 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -82,7 +82,7 @@ class ModelTester: dropout=self.hidden_dropout_prob, attention_dropout=self.attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, - eos_token_ids=self.eos_token_id, + eos_token_ids=[self.eos_token_id], bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, ) @@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase): decoder_ffn_dim=32, max_position_embeddings=48, output_past=output_past, - eos_token_id=2, + eos_token_ids=[2], pad_token_id=1, bos_token_id=0, ) @@ -276,7 +276,7 @@ class BartHeadTests(unittest.TestCase): decoder_ffn_dim=32, max_position_embeddings=48, output_past=True, - eos_token_ids=2, + eos_token_ids=[2], pad_token_id=1, bos_token_id=0, ) @@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase): new_input_ids = lm_model.generate( input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length ) - self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length)) + self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1)) # TODO(SS): uneven length batches, empty inputs def test_shift_tokens_right(self):