From 1529bf96805db3461ed26f7ffae1ca8b79ee278a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 18 Aug 2020 19:15:50 -0400 Subject: [PATCH] add BartConfig.force_bos_token_to_be_generated (#6526) Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/configuration_bart.py | 5 +++++ src/transformers/modeling_bart.py | 18 +++++------------- src/transformers/modeling_marian.py | 2 +- tests/test_modeling_bart.py | 2 +- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 8edcb4e69c..4840998295 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -95,6 +95,8 @@ BART_CONFIG_ARGS_DOC = r""" for SequenceClassification is_encoder_decoder (:obj:`int`, optional, defaults to True): True + force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only true for `bart-large-cnn`. """ @@ -137,6 +139,7 @@ class BartConfig(PretrainedConfig): normalize_embedding=True, static_position_embeddings=False, add_bias_logits=False, + force_bos_token_to_be_generated=False, **common_kwargs ): r""" @@ -195,6 +198,8 @@ class BartConfig(PretrainedConfig): # pos embedding offset self.extra_pos_embeddings = self.pad_token_id + 1 + self.force_bos_token_to_be_generated = force_bos_token_to_be_generated + @property def num_attention_heads(self) -> int: return self.encoder_attention_heads diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 93f0e96ab8..e8c107be7f 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -1073,23 +1073,15 @@ class BartForConditionalGeneration(PretrainedBartModel): } def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == 1: + if cur_len == 1 and self.config.force_bos_token_to_be_generated: self._force_token_ids_generation(logits, self.config.bos_token_id) - if cur_len == max_length - 1 and self.config.eos_token_id is not None: + elif cur_len == max_length - 1 and self.config.eos_token_id is not None: self._force_token_ids_generation(logits, self.config.eos_token_id) return logits - def _force_token_ids_generation(self, scores, token_ids) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0""" - if isinstance(token_ids, int): - token_ids = [token_ids] - all_but_token_ids_mask = torch.tensor( - [x for x in range(self.config.vocab_size) if x not in token_ids], - dtype=torch.long, - device=next(self.parameters()).device, - ) - assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]" - scores[:, all_but_token_ids_mask] = -float("inf") + def _force_token_ids_generation(self, scores, token_id) -> None: + """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" + scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf") @staticmethod def _reorder_cache(past, beam_idx): diff --git a/src/transformers/modeling_marian.py b/src/transformers/modeling_marian.py index bde0c62788..97fe451a38 100644 --- a/src/transformers/modeling_marian.py +++ b/src/transformers/modeling_marian.py @@ -47,7 +47,7 @@ class MarianMTModel(BartForConditionalGeneration): """ def adjust_logits_during_generation(self, logits, cur_len, max_length): - logits[:, self.config.pad_token_id] = float("-inf") + logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token. if cur_len == max_length - 1 and self.config.eos_token_id is not None: self._force_token_ids_generation(logits, self.config.eos_token_id) return logits diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 886848618f..11dea766a4 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -484,7 +484,7 @@ class BartModelIntegrationTests(unittest.TestCase): self.assertFalse(model.config.is_valid_mbart()) tok = BartTokenizer.from_pretrained("facebook/bart-large") - EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state." + EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state." dct = tok.batch_encode_plus( [PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt", ).to(torch_device)