From c0d9dd3ba96fa4f83a9c3c566dff01d7a0a6608b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 5 Mar 2020 14:56:56 +0100 Subject: [PATCH] refactored code a bit and made more generic --- src/transformers/configuration_utils.py | 1 + src/transformers/modeling_utils.py | 39 ++++++++++++++++--------- tests/test_modeling_bart.py | 2 +- tests/test_modeling_common.py | 2 +- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 1ada6ade8b..09ef959dd3 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -69,6 +69,7 @@ class PretrainedConfig(object): # Parameters for sequence generation self.max_length = kwargs.pop("max_length", 20) + self.min_length = kwargs.pop("min_length", 0) self.do_sample = kwargs.pop("do_sample", False) self.early_stopping = kwargs.pop("early_stopping", False) self.num_beams = kwargs.pop("num_beams", 1) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 74d20948c0..93347b4525 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -609,6 +609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): self, input_ids=None, max_length=None, + min_length=None, do_sample=True, num_beams=None, temperature=None, @@ -713,6 +714,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ) max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length do_sample = do_sample if do_sample is not None else self.config.do_sample num_beams = num_beams if num_beams is not None else self.config.num_beams temperature = temperature if temperature is not None else self.config.temperature @@ -735,6 +737,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): eos_token_ids = [eos_token_ids] assert isinstance(max_length, int) and max_length > 0, 
"`max_length` should be a strictly positive integer." + assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." assert temperature > 0, "`temperature` should be strictly positive." @@ -824,12 +827,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): encoder_inputs = input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), -# eos_token_id, +# eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no? bos_token_id, dtype=torch.long, device=next(self.parameters()).device, ) - cur_len = 0 + cur_len = 1 self.model.decoder.generation_mode = True else: encoder_inputs = None @@ -840,6 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): input_ids, cur_len, max_length, + min_length, do_sample, temperature, top_k, @@ -859,6 +863,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): input_ids, cur_len, max_length, + min_length, do_sample, temperature, top_k, @@ -877,6 +882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): input_ids, cur_len, max_length, + min_length, do_sample, temperature, top_k, @@ -911,6 +917,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if repetition_penalty != 1.0: self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty) + if eos_token_ids is not None and cur_len < min_length: + for eos_token_id in eos_token_ids: + next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks + if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) if temperature != 1.0: @@ -965,6 +975,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): input_ids, cur_len, max_length, + min_length, do_sample, temperature, top_k, @@ -1022,6 +1033,10 @@ class 
PreTrainedModel(nn.Module, ModuleUtilsMixin): next_token_logits, batch_size, num_beams, input_ids, repetition_penalty ) + if eos_token_ids is not None and cur_len < min_length: + for eos_token_id in eos_token_ids: + next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks + if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) if temperature != 1.0: @@ -1056,18 +1071,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) if is_encoder_decoder: # TODO(PVP) to be refactored later - import math -# scores[scores != scores] = -math.inf # block nans -# scores[:, pad_token_id] = -math.inf +# scores[scores != scores] = -math.inf # block nans => seems very hacky here +# scores[:, pad_token_id] = -math.inf => seems very hacky here # TODO(SS): fairseq also takes out every step, and has unk at slot 3 -# if cur_len == 0: # Force BOS to be chosen -# scores[:, self.config.bos_token_id + 1 :] = -math.inf # TODO(PVP) should not use bos_token_id here -# elif cur_len < min_len: # Prevent EOS from being chosen TODO: for the moment don't think about min_len -# scores[:, eos_token_ids[0]] = -math.inf -# elif cur_len == max_length: # FORCE EOS to be chosen - if cur_len == max_length: # FORCE EOS to be chosen - scores[:, :eos_token_ids[0]] = -math.inf - scores[:, eos_token_ids[0] + 1 :] = -math.inf +# if cur_len == 0: # Force BOS to be chosen => also very hacky ... 
seems also to work without this line +# scores[:, self.config.bos_token_id + 1 :] = -math.inf + if cur_len == max_length - 1: # FORCE EOS to be chosen + all_but_eos_mask = torch.tensor([x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long, device=next(self.parameters()).device) + scores[:, all_but_eos_mask] = -10000.0 assert scores.size() == (batch_size * num_beams, vocab_size) # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product) @@ -1194,7 +1205,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # shorter batches are filled with pad_token if sent_lengths.min().item() != sent_lengths.max().item(): assert pad_token_id is not None, "`Pad_token_id` has to be defined" - sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1) + sent_max_len = min(sent_lengths.max().item() + 1, max_length) decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) # fill with hypothesis and eos_token_id if necessary diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 3a335fabb9..ffc5a004f7 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -442,7 +442,7 @@ class BartModelIntegrationTest(unittest.TestCase): tokens = tok.encode(text, return_tensors="pt").to(torch_device) extra_len = 20 gen_tokens_1 = hf.generate_1(tokens, num_beams=4, max_length=extra_len,) # repetition_penalty=10., - gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len, do_sample=False) # repetition_penalty=10., + gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len + 2, do_sample=False) # repetition_penalty=10., print("1: {}".format(gen_tokens_1)) print("2: {}".format(gen_tokens)) ipdb.set_trace() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ce3bda8b45..4abc183218 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -621,7 +621,7 @@ class ModelTesterMixin: 
with torch.no_grad(): model(**inputs_dict) - def _A_test_lm_head_model_random_generate(self): + def test_lm_head_model_random_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict.get(