From 6047f46b199ba49f353b31d7bedad2b3e076f52e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 12 Mar 2020 20:17:50 +0100 Subject: [PATCH] re-add eos token to get good bart results --- src/transformers/modeling_utils.py | 10 +++++++++- tests/test_modeling_bart.py | 7 ++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 253844ad46..57b4204a53 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -628,6 +628,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): no_repeat_ngram_size=None, num_return_sequences=None, attention_mask=None, + decoder_start_token_id=None, ): r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling and beam-search. @@ -739,6 +740,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) + # TODO: think about how to make this cleaner + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id + ) if input_ids is not None: batch_size = input_ids.shape[0] # overriden by the input batch_size @@ -765,6 +770,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): assert (eos_token_ids is None) or ( isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids) ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." + assert ( + decoder_start_token_id is not None or self.config.is_encoder_decoder is False + ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" assert length_penalty > 0, "`length_penalty` should be strictly positive." assert ( isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 @@ -845,7 +853,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): encoder_inputs = input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), - bos_token_id, # TODO: wait for results of Bart CNN summarization + decoder_start_token_id, # TODO: see whether this is the best result dtype=torch.long, device=next(self.parameters()).device, ) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index b885ccf1b6..8ad02feb2a 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -432,7 +432,11 @@ class BartModelIntegrationTest(unittest.TestCase): tokens = tok.encode(text, return_tensors="pt").to(torch_device) extra_len = 20 gen_tokens = hf.generate( - tokens, num_beams=4, max_length=extra_len + 2, do_sample=False + tokens, + num_beams=4, + max_length=extra_len + 2, + do_sample=False, + decoder_start_token_id=hf.config.eos_token_id, ) # repetition_penalty=10., expected_result = "The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday." generated = [tok.decode(g,) for g in gen_tokens] @@ -477,6 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase): no_repeat_ngram_size=3, do_sample=False, early_stopping=True, + decoder_start_token_id=hf.config.eos_token_id, ) decoded = [