Remove `max_length` from the beam scorer (#11378)

* removed max_len

* removed max_length from BeamSearchScorer

* correct max length

* finish

* del vim

* finish & add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Ashwin Geet D'Sa
2021-04-27 00:28:40 +02:00
committed by GitHub
parent bc2571e61c
commit 741d48f5c7
6 changed files with 91 additions and 38 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict from collections import UserDict
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -110,6 +111,7 @@ class BeamScorer(ABC):
next_scores: torch.FloatTensor, next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor, next_tokens: torch.LongTensor,
next_indices: torch.LongTensor, next_indices: torch.LongTensor,
max_length: int,
**kwargs **kwargs
) -> torch.LongTensor: ) -> torch.LongTensor:
raise NotImplementedError("This is an abstract method.") raise NotImplementedError("This is an abstract method.")
@@ -152,15 +154,14 @@ class BeamSearchScorer(BeamScorer):
def __init__( def __init__(
self, self,
batch_size: int, batch_size: int,
max_length: int,
num_beams: int, num_beams: int,
device: torch.device, device: torch.device,
length_penalty: Optional[float] = 1.0, length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False, do_early_stopping: Optional[bool] = False,
num_beam_hyps_to_keep: Optional[int] = 1, num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1, num_beam_groups: Optional[int] = 1,
**kwargs,
): ):
self.max_length = max_length
self.num_beams = num_beams self.num_beams = num_beams
self.device = device self.device = device
self.length_penalty = length_penalty self.length_penalty = length_penalty
@@ -173,7 +174,6 @@ class BeamSearchScorer(BeamScorer):
self._beam_hyps = [ self._beam_hyps = [
BeamHypotheses( BeamHypotheses(
num_beams=self.num_beams, num_beams=self.num_beams,
max_length=self.max_length,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping, early_stopping=self.do_early_stopping,
) )
@@ -192,6 +192,13 @@ class BeamSearchScorer(BeamScorer):
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
) )
if "max_length" in kwargs:
warnings.warn(
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect."
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
",or `group_beam_search(...)`."
)
@property @property
def is_done(self) -> bool: def is_done(self) -> bool:
return self._done.all() return self._done.all()
@@ -279,6 +286,7 @@ class BeamSearchScorer(BeamScorer):
final_beam_scores: torch.FloatTensor, final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor, final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor, final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None, pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
) -> Tuple[torch.LongTensor]: ) -> Tuple[torch.LongTensor]:
@@ -316,7 +324,7 @@ class BeamSearchScorer(BeamScorer):
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
# prepare for adding eos # prepare for adding eos
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) sent_max_len = min(sent_lengths.max().item() + 1, max_length)
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed # shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item(): if sent_lengths.min().item() != sent_lengths.max().item():
@@ -326,7 +334,7 @@ class BeamSearchScorer(BeamScorer):
# fill with hypotheses and eos_token_id if the latter fits in # fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best): for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < self.max_length: if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id decoded[i, sent_lengths[i]] = eos_token_id
return UserDict( return UserDict(
{ {
@@ -337,11 +345,10 @@ class BeamSearchScorer(BeamScorer):
class BeamHypotheses: class BeamHypotheses:
def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
""" """
Initialize n-best list of hypotheses. Initialize n-best list of hypotheses.
""" """
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty self.length_penalty = length_penalty
self.early_stopping = early_stopping self.early_stopping = early_stopping
self.num_beams = num_beams self.num_beams = num_beams

View File

@@ -1027,7 +1027,6 @@ class GenerationMixin:
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=stopping_criteria.max_length,
num_beams=num_beams, num_beams=num_beams,
device=self.device, device=self.device,
length_penalty=length_penalty, length_penalty=length_penalty,
@@ -1063,7 +1062,6 @@ class GenerationMixin:
raise ValueError("`max_length` needs to be a stopping_criteria for now.") raise ValueError("`max_length` needs to be a stopping_criteria for now.")
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=stopping_criteria.max_length,
num_beams=num_beams, num_beams=num_beams,
device=self.device, device=self.device,
length_penalty=length_penalty, length_penalty=length_penalty,
@@ -1700,7 +1698,6 @@ class GenerationMixin:
>>> # instantiate beam scorer >>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer( >>> beam_scorer = BeamSearchScorer(
... batch_size=1, ... batch_size=1,
... max_length=model.config.max_length,
... num_beams=num_beams, ... num_beams=num_beams,
... device=model.device, ... device=model.device,
... ) ... )
@@ -1756,7 +1753,7 @@ class GenerationMixin:
assert ( assert (
num_beams * batch_size == batch_beam_size num_beams * batch_size == batch_beam_size
), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9 beam_scores[:, 1:] = -1e9
@@ -1792,10 +1789,7 @@ class GenerationMixin:
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation. # cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation( next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_logits, cur_len=cur_len, max_length=None
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
next_token_scores = logits_processor(input_ids, next_token_scores) next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -1861,7 +1855,13 @@ class GenerationMixin:
this_peer_finished = True this_peer_finished = True
sequence_outputs = beam_scorer.finalize( sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
) )
if return_dict_in_generate: if return_dict_in_generate:
@@ -2086,10 +2086,7 @@ class GenerationMixin:
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation. # cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation( next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_logits, cur_len=cur_len, max_length=None
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
next_token_scores = logits_processor(input_ids, next_token_scores) next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -2160,7 +2157,13 @@ class GenerationMixin:
this_peer_finished = True this_peer_finished = True
sequence_outputs = beam_scorer.finalize( sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
) )
if return_dict_in_generate: if return_dict_in_generate:
@@ -2411,10 +2414,7 @@ class GenerationMixin:
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation. # cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation( next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_logits, cur_len=cur_len, max_length=None
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
vocab_size = next_token_scores.shape[-1] vocab_size = next_token_scores.shape[-1]
@@ -2497,7 +2497,13 @@ class GenerationMixin:
this_peer_finished = True this_peer_finished = True
sequence_outputs = beam_scorer.finalize( sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
) )
if return_dict_in_generate: if return_dict_in_generate:

View File

@@ -1335,7 +1335,7 @@ class MarianMTModel(MarianPreTrainedModel):
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length): def adjust_logits_during_generation(self, logits, cur_len):
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token. logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
return logits return logits

View File

@@ -1543,7 +1543,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=max_length,
num_beams=num_beams, num_beams=num_beams,
device=self.device, device=self.device,
length_penalty=length_penalty, length_penalty=length_penalty,

View File

@@ -59,7 +59,6 @@ class BeamSearchTester:
def prepare_beam_scorer(self, **kwargs): def prepare_beam_scorer(self, **kwargs):
return BeamSearchScorer( return BeamSearchScorer(
batch_size=kwargs.get("batch_size", self.batch_size), batch_size=kwargs.get("batch_size", self.batch_size),
max_length=kwargs.get("max_length", self.max_length),
num_beams=kwargs.get("num_beams", self.num_beams), num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device, device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty), length_penalty=kwargs.get("length_penalty", self.length_penalty),
@@ -170,9 +169,7 @@ class BeamSearchTester:
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores): def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
# max_length should be only one more than current input_ids to check that eos is correctly appended # max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1 max_length = self.sequence_length + 1
beam_scorer = self.prepare_beam_scorer( beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
)
# update beams and append to input_ids # update beams and append to input_ids
tokens = next_tokens.clone() tokens = next_tokens.clone()
@@ -197,6 +194,7 @@ class BeamSearchTester:
output_indices, output_indices,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
max_length=max_length,
) )
sequences = sequence_output["sequences"] sequences = sequence_output["sequences"]
@@ -225,6 +223,7 @@ class BeamSearchTester:
output_indices, output_indices,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
max_length=max_length,
) )
sequences = sequence_output["sequences"] sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"] sequence_scores = sequence_output["sequence_scores"]

View File

@@ -148,7 +148,6 @@ class GenerationTesterMixin:
} }
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=max_length,
num_beams=beam_kwargs["num_beams"], num_beams=beam_kwargs["num_beams"],
device=torch_device, device=torch_device,
length_penalty=beam_kwargs["length_penalty"], length_penalty=beam_kwargs["length_penalty"],
@@ -169,7 +168,6 @@ class GenerationTesterMixin:
} }
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=max_length,
num_beams=beam_kwargs["num_beams"], num_beams=beam_kwargs["num_beams"],
device=torch_device, device=torch_device,
length_penalty=beam_kwargs["length_penalty"], length_penalty=beam_kwargs["length_penalty"],
@@ -1411,7 +1409,6 @@ class GenerationIntegrationTests(unittest.TestCase):
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=max_length,
num_beams=num_beams, num_beams=num_beams,
device=torch_device, device=torch_device,
) )
@@ -1442,7 +1439,6 @@ class GenerationIntegrationTests(unittest.TestCase):
diverse_beam_scorer = BeamSearchScorer( diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=max_length,
num_beams=num_beams, num_beams=num_beams,
device=torch_device, device=torch_device,
num_beam_hyps_to_keep=num_return_sequences, num_beam_hyps_to_keep=num_return_sequences,
@@ -1502,7 +1498,6 @@ class GenerationIntegrationTests(unittest.TestCase):
# Beam # Beam
beam_scorer = BeamSearchScorer( beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=max_length,
num_beams=num_beams, num_beams=num_beams,
device=torch_device, device=torch_device,
) )
@@ -1520,7 +1515,6 @@ class GenerationIntegrationTests(unittest.TestCase):
# Grouped beam search # Grouped beam search
diverse_beam_scorer = BeamSearchScorer( diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size, batch_size=batch_size,
max_length=max_length,
num_beams=num_beams, num_beams=num_beams,
device=torch_device, device=torch_device,
num_beam_hyps_to_keep=num_return_sequences, num_beam_hyps_to_keep=num_return_sequences,
@@ -1535,3 +1529,51 @@ class GenerationIntegrationTests(unittest.TestCase):
max_length=max_length, max_length=max_length,
**model_kwargs, **model_kwargs,
) )
def test_beam_search_warning_if_max_length_is_passed(self):
    """Check that a legacy `max_length` given to `BeamSearchScorer` only warns.

    Two scorers are built with the same settings, one of them with an extra
    `max_length=10` keyword. Building that one must emit a `UserWarning`.
    Both scorers are then driven through `beam_search` with the same
    stopping criteria (max length 18), and the generated ids must match,
    i.e. the scorer-level value must not affect the generated ids.
    """
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
    checkpoint = "sshleifer/bart-tiny-random"
    tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint).to(torch_device)

    num_beams = 3
    # Settings shared by both scorers; only the legacy keyword differs.
    scorer_kwargs = {"batch_size": 1, "num_beams": num_beams, "device": torch_device}

    encoded = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
    input_ids = encoded.expand(num_beams, -1)
    model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})

    # The length limit that is expected to actually govern generation.
    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=18)])

    def run_beam_search(scorer):
        # Identical call for both scorers so that the outputs are comparable.
        return model.beam_search(
            input_ids,
            num_beams=num_beams,
            stopping_criteria=stopping_criteria,
            beam_scorer=scorer,
            **model_kwargs,
        )

    # Constructing the scorer with the deprecated keyword must warn.
    with self.assertWarns(UserWarning):
        legacy_scorer = BeamSearchScorer(max_length=10, **scorer_kwargs)
    ids_with_legacy_arg = run_beam_search(legacy_scorer)

    plain_scorer = BeamSearchScorer(**scorer_kwargs)
    ids_without_legacy_arg = run_beam_search(plain_scorer)

    # BeamSearchScorer max_length should not influence "real" max_length
    self.assertEqual(ids_with_legacy_arg.tolist(), ids_without_legacy_arg.tolist())