Remove max length beam scorer (#11378)
* removed max_len
* removed max_length from BeamSearchScorer
* correct max length
* finish
* del vim
* finish & add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -110,6 +111,7 @@ class BeamScorer(ABC):
|
|||||||
next_scores: torch.FloatTensor,
|
next_scores: torch.FloatTensor,
|
||||||
next_tokens: torch.LongTensor,
|
next_tokens: torch.LongTensor,
|
||||||
next_indices: torch.LongTensor,
|
next_indices: torch.LongTensor,
|
||||||
|
max_length: int,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
raise NotImplementedError("This is an abstract method.")
|
raise NotImplementedError("This is an abstract method.")
|
||||||
@@ -152,15 +154,14 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
max_length: int,
|
|
||||||
num_beams: int,
|
num_beams: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
length_penalty: Optional[float] = 1.0,
|
length_penalty: Optional[float] = 1.0,
|
||||||
do_early_stopping: Optional[bool] = False,
|
do_early_stopping: Optional[bool] = False,
|
||||||
num_beam_hyps_to_keep: Optional[int] = 1,
|
num_beam_hyps_to_keep: Optional[int] = 1,
|
||||||
num_beam_groups: Optional[int] = 1,
|
num_beam_groups: Optional[int] = 1,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.max_length = max_length
|
|
||||||
self.num_beams = num_beams
|
self.num_beams = num_beams
|
||||||
self.device = device
|
self.device = device
|
||||||
self.length_penalty = length_penalty
|
self.length_penalty = length_penalty
|
||||||
@@ -173,7 +174,6 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
self._beam_hyps = [
|
self._beam_hyps = [
|
||||||
BeamHypotheses(
|
BeamHypotheses(
|
||||||
num_beams=self.num_beams,
|
num_beams=self.num_beams,
|
||||||
max_length=self.max_length,
|
|
||||||
length_penalty=self.length_penalty,
|
length_penalty=self.length_penalty,
|
||||||
early_stopping=self.do_early_stopping,
|
early_stopping=self.do_early_stopping,
|
||||||
)
|
)
|
||||||
@@ -192,6 +192,13 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "max_length" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect."
|
||||||
|
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
|
||||||
|
",or `group_beam_search(...)`."
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_done(self) -> bool:
|
def is_done(self) -> bool:
|
||||||
return self._done.all()
|
return self._done.all()
|
||||||
@@ -279,6 +286,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
final_beam_scores: torch.FloatTensor,
|
final_beam_scores: torch.FloatTensor,
|
||||||
final_beam_tokens: torch.LongTensor,
|
final_beam_tokens: torch.LongTensor,
|
||||||
final_beam_indices: torch.LongTensor,
|
final_beam_indices: torch.LongTensor,
|
||||||
|
max_length: int,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
) -> Tuple[torch.LongTensor]:
|
) -> Tuple[torch.LongTensor]:
|
||||||
@@ -316,7 +324,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||||
|
|
||||||
# prepare for adding eos
|
# prepare for adding eos
|
||||||
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
|
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||||
# shorter batches are padded if needed
|
# shorter batches are padded if needed
|
||||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||||
@@ -326,7 +334,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
# fill with hypotheses and eos_token_id if the latter fits in
|
# fill with hypotheses and eos_token_id if the latter fits in
|
||||||
for i, hypo in enumerate(best):
|
for i, hypo in enumerate(best):
|
||||||
decoded[i, : sent_lengths[i]] = hypo
|
decoded[i, : sent_lengths[i]] = hypo
|
||||||
if sent_lengths[i] < self.max_length:
|
if sent_lengths[i] < max_length:
|
||||||
decoded[i, sent_lengths[i]] = eos_token_id
|
decoded[i, sent_lengths[i]] = eos_token_id
|
||||||
return UserDict(
|
return UserDict(
|
||||||
{
|
{
|
||||||
@@ -337,11 +345,10 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
|
|
||||||
|
|
||||||
class BeamHypotheses:
|
class BeamHypotheses:
|
||||||
def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
|
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
|
||||||
"""
|
"""
|
||||||
Initialize n-best list of hypotheses.
|
Initialize n-best list of hypotheses.
|
||||||
"""
|
"""
|
||||||
self.max_length = max_length - 1 # ignoring bos_token
|
|
||||||
self.length_penalty = length_penalty
|
self.length_penalty = length_penalty
|
||||||
self.early_stopping = early_stopping
|
self.early_stopping = early_stopping
|
||||||
self.num_beams = num_beams
|
self.num_beams = num_beams
|
||||||
|
|||||||
@@ -1027,7 +1027,6 @@ class GenerationMixin:
|
|||||||
|
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=stopping_criteria.max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
@@ -1063,7 +1062,6 @@ class GenerationMixin:
|
|||||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=stopping_criteria.max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
@@ -1700,7 +1698,6 @@ class GenerationMixin:
|
|||||||
>>> # instantiate beam scorer
|
>>> # instantiate beam scorer
|
||||||
>>> beam_scorer = BeamSearchScorer(
|
>>> beam_scorer = BeamSearchScorer(
|
||||||
... batch_size=1,
|
... batch_size=1,
|
||||||
... max_length=model.config.max_length,
|
|
||||||
... num_beams=num_beams,
|
... num_beams=num_beams,
|
||||||
... device=model.device,
|
... device=model.device,
|
||||||
... )
|
... )
|
||||||
@@ -1756,7 +1753,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
num_beams * batch_size == batch_beam_size
|
num_beams * batch_size == batch_beam_size
|
||||||
), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||||
|
|
||||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||||
beam_scores[:, 1:] = -1e9
|
beam_scores[:, 1:] = -1e9
|
||||||
@@ -1792,10 +1789,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
||||||
next_token_logits, cur_len=cur_len, max_length=None
|
|
||||||
)
|
|
||||||
|
|
||||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
next_token_scores = logits_processor(input_ids, next_token_scores)
|
||||||
@@ -1861,7 +1855,13 @@ class GenerationMixin:
|
|||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
sequence_outputs = beam_scorer.finalize(
|
sequence_outputs = beam_scorer.finalize(
|
||||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
input_ids,
|
||||||
|
beam_scores,
|
||||||
|
next_tokens,
|
||||||
|
next_indices,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
max_length=stopping_criteria.max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
@@ -2086,10 +2086,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
||||||
next_token_logits, cur_len=cur_len, max_length=None
|
|
||||||
)
|
|
||||||
|
|
||||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
next_token_scores = logits_processor(input_ids, next_token_scores)
|
||||||
@@ -2160,7 +2157,13 @@ class GenerationMixin:
|
|||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
sequence_outputs = beam_scorer.finalize(
|
sequence_outputs = beam_scorer.finalize(
|
||||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
input_ids,
|
||||||
|
beam_scores,
|
||||||
|
next_tokens,
|
||||||
|
next_indices,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
max_length=stopping_criteria.max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
@@ -2411,10 +2414,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||||
# cannot be generated both before and after the `F.log_softmax` operation.
|
# cannot be generated both before and after the `F.log_softmax` operation.
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
||||||
next_token_logits, cur_len=cur_len, max_length=None
|
|
||||||
)
|
|
||||||
|
|
||||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
||||||
vocab_size = next_token_scores.shape[-1]
|
vocab_size = next_token_scores.shape[-1]
|
||||||
|
|
||||||
@@ -2497,7 +2497,13 @@ class GenerationMixin:
|
|||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
sequence_outputs = beam_scorer.finalize(
|
sequence_outputs = beam_scorer.finalize(
|
||||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
input_ids,
|
||||||
|
beam_scores,
|
||||||
|
next_tokens,
|
||||||
|
next_indices,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
max_length=stopping_criteria.max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
|
|||||||
@@ -1335,7 +1335,7 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len):
|
||||||
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|||||||
@@ -1543,7 +1543,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ class BeamSearchTester:
|
|||||||
def prepare_beam_scorer(self, **kwargs):
|
def prepare_beam_scorer(self, **kwargs):
|
||||||
return BeamSearchScorer(
|
return BeamSearchScorer(
|
||||||
batch_size=kwargs.get("batch_size", self.batch_size),
|
batch_size=kwargs.get("batch_size", self.batch_size),
|
||||||
max_length=kwargs.get("max_length", self.max_length),
|
|
||||||
num_beams=kwargs.get("num_beams", self.num_beams),
|
num_beams=kwargs.get("num_beams", self.num_beams),
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
||||||
@@ -170,9 +169,7 @@ class BeamSearchTester:
|
|||||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||||
max_length = self.sequence_length + 1
|
max_length = self.sequence_length + 1
|
||||||
beam_scorer = self.prepare_beam_scorer(
|
beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
|
||||||
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# update beams and append to input_ids
|
# update beams and append to input_ids
|
||||||
tokens = next_tokens.clone()
|
tokens = next_tokens.clone()
|
||||||
@@ -197,6 +194,7 @@ class BeamSearchTester:
|
|||||||
output_indices,
|
output_indices,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
sequences = sequence_output["sequences"]
|
sequences = sequence_output["sequences"]
|
||||||
@@ -225,6 +223,7 @@ class BeamSearchTester:
|
|||||||
output_indices,
|
output_indices,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
sequences = sequence_output["sequences"]
|
sequences = sequence_output["sequences"]
|
||||||
sequence_scores = sequence_output["sequence_scores"]
|
sequence_scores = sequence_output["sequence_scores"]
|
||||||
|
|||||||
@@ -148,7 +148,6 @@ class GenerationTesterMixin:
|
|||||||
}
|
}
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=beam_kwargs["num_beams"],
|
num_beams=beam_kwargs["num_beams"],
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
length_penalty=beam_kwargs["length_penalty"],
|
length_penalty=beam_kwargs["length_penalty"],
|
||||||
@@ -169,7 +168,6 @@ class GenerationTesterMixin:
|
|||||||
}
|
}
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=beam_kwargs["num_beams"],
|
num_beams=beam_kwargs["num_beams"],
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
length_penalty=beam_kwargs["length_penalty"],
|
length_penalty=beam_kwargs["length_penalty"],
|
||||||
@@ -1411,7 +1409,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
@@ -1442,7 +1439,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
diverse_beam_scorer = BeamSearchScorer(
|
diverse_beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
num_beam_hyps_to_keep=num_return_sequences,
|
num_beam_hyps_to_keep=num_return_sequences,
|
||||||
@@ -1502,7 +1498,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# Beam
|
# Beam
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
@@ -1520,7 +1515,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# Grouped beam search
|
# Grouped beam search
|
||||||
diverse_beam_scorer = BeamSearchScorer(
|
diverse_beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_length=max_length,
|
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
num_beam_hyps_to_keep=num_return_sequences,
|
num_beam_hyps_to_keep=num_return_sequences,
|
||||||
@@ -1535,3 +1529,51 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_beam_search_warning_if_max_length_is_passed(self):
|
||||||
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
|
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||||
|
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||||
|
|
||||||
|
batch_size = 1
|
||||||
|
num_beams = 3
|
||||||
|
|
||||||
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
input_ids = input_ids.expand(num_beams, -1)
|
||||||
|
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||||
|
|
||||||
|
stopping_criteria_max_length = 18
|
||||||
|
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
|
||||||
|
|
||||||
|
with self.assertWarns(UserWarning):
|
||||||
|
beam_scorer = BeamSearchScorer(
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_beams=num_beams,
|
||||||
|
device=torch_device,
|
||||||
|
max_length=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_ids = bart_model.beam_search(
|
||||||
|
input_ids,
|
||||||
|
num_beams=num_beams,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
beam_scorer=beam_scorer,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
beam_scorer_no_max_len = BeamSearchScorer(
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_beams=num_beams,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_ids_no_max_len = bart_model.beam_search(
|
||||||
|
input_ids,
|
||||||
|
num_beams=num_beams,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
beam_scorer=beam_scorer_no_max_len,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# BeamSearchScorer max_length should not influence "real" max_length
|
||||||
|
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
|
||||||
|
|||||||
Reference in New Issue
Block a user