From 333affcb8120f9918a2c49042c2a6dc92f954999 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 6 Mar 2020 15:14:36 +0100 Subject: [PATCH] add current changes --- src/transformers/modeling_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b20ca0b657..e9b8ee59f9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -614,6 +614,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): max_length=None, min_length=None, do_sample=True, + early_stopping=False, num_beams=None, temperature=None, top_k=None, @@ -720,7 +721,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): max_length = max_length if max_length is not None else self.config.max_length min_length = min_length if min_length is not None else self.config.min_length - do_sample = do_sample if do_sample is not None else self.config.do_sample + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping num_beams = num_beams if num_beams is not None else self.config.num_beams temperature = temperature if temperature is not None else self.config.temperature top_k = top_k if top_k is not None else self.config.top_k @@ -747,6 +748,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." + assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." assert temperature > 0, "`temperature` should be strictly positive." assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." @@ -841,8 +843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): encoder_inputs = input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), -# eos_token_id, - bos_token_id, + eos_token_id, +# bos_token_id, # eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case dtype=torch.long, device=next(self.parameters()).device, @@ -860,6 +862,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): max_length, min_length, do_sample, + early_stopping, temperature, top_k, top_p, @@ -1012,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): max_length, min_length, do_sample, + early_stopping, temperature, top_k, top_p, @@ -1033,7 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # generated hypotheses generated_hyps = [ - BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size) + BeamHypotheses(num_beams, max_length - 1, length_penalty, early_stopping=early_stopping) for _ in range(batch_size) ] # scores for each sentence in the beam @@ -1080,11 +1084,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # force eos to be chosen at end of generation for encoder-decoder models # TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently if self.config.is_encoder_decoder: -# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length) if cur_len == 1: self._force_token_ids_generation(next_token_logits, bos_token_id) if cur_len == max_length - 1: self._force_token_ids_generation(next_token_logits, eos_token_ids) +# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length) if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens)