add current changes
This commit is contained in:
@@ -614,6 +614,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
max_length=None,
|
||||
min_length=None,
|
||||
do_sample=True,
|
||||
early_stopping=False,
|
||||
num_beams=None,
|
||||
temperature=None,
|
||||
top_k=None,
|
||||
@@ -720,7 +721,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
@@ -747,6 +748,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictly positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
@@ -841,8 +843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
encoder_inputs = input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
# eos_token_id,
|
||||
bos_token_id,
|
||||
eos_token_id,
|
||||
# bos_token_id,
|
||||
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
@@ -860,6 +862,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
early_stopping,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
@@ -1012,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
early_stopping,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
@@ -1033,7 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
|
||||
BeamHypotheses(num_beams, max_length - 1, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
# scores for each sentence in the beam
|
||||
@@ -1080,11 +1084,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# force eos to be chosen at end of generation for encoder-decoder models
|
||||
# TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently
|
||||
if self.config.is_encoder_decoder:
|
||||
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
|
||||
if cur_len == 1:
|
||||
self._force_token_ids_generation(next_token_logits, bos_token_id)
|
||||
if cur_len == max_length - 1:
|
||||
self._force_token_ids_generation(next_token_logits, eos_token_ids)
|
||||
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
|
||||
Reference in New Issue
Block a user