add current changes

This commit is contained in:
Patrick von Platen
2020-03-06 15:14:36 +01:00
parent 421216997b
commit 333affcb81

View File

@@ -614,6 +614,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length=None, max_length=None,
min_length=None, min_length=None,
do_sample=True, do_sample=True,
early_stopping=False,
num_beams=None, num_beams=None,
temperature=None, temperature=None,
top_k=None, top_k=None,
@@ -720,7 +721,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length = max_length if max_length is not None else self.config.max_length max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beams = num_beams if num_beams is not None else self.config.num_beams num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k top_k = top_k if top_k is not None else self.config.top_k
@@ -747,6 +748,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive." assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
@@ -841,8 +843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs = input_ids encoder_inputs = input_ids
input_ids = torch.full( input_ids = torch.full(
(effective_batch_size * num_beams, 1), (effective_batch_size * num_beams, 1),
# eos_token_id, eos_token_id,
bos_token_id, # bos_token_id,
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case # eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
dtype=torch.long, dtype=torch.long,
device=next(self.parameters()).device, device=next(self.parameters()).device,
@@ -860,6 +862,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length, max_length,
min_length, min_length,
do_sample, do_sample,
early_stopping,
temperature, temperature,
top_k, top_k,
top_p, top_p,
@@ -1012,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length, max_length,
min_length, min_length,
do_sample, do_sample,
early_stopping,
temperature, temperature,
top_k, top_k,
top_p, top_p,
@@ -1033,7 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# generated hypotheses # generated hypotheses
generated_hyps = [ generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size) BeamHypotheses(num_beams, max_length - 1, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
] ]
# scores for each sentence in the beam # scores for each sentence in the beam
@@ -1080,11 +1084,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# force eos to be chosen at end of generation for encoder-decoder models # force eos to be chosen at end of generation for encoder-decoder models
# TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently # TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
if cur_len == 1: if cur_len == 1:
self._force_token_ids_generation(next_token_logits, bos_token_id) self._force_token_ids_generation(next_token_logits, bos_token_id)
if cur_len == max_length - 1: if cur_len == max_length - 1:
self._force_token_ids_generation(next_token_logits, eos_token_ids) self._force_token_ids_generation(next_token_logits, eos_token_ids)
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
if do_sample: if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens) # Temperature (higher temperature => more likely to sample low probability tokens)