Re-add the EOS token as the decoder start token to get good BART results
This commit is contained in:
@@ -628,6 +628,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
no_repeat_ngram_size=None,
|
no_repeat_ngram_size=None,
|
||||||
num_return_sequences=None,
|
num_return_sequences=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
decoder_start_token_id=None,
|
||||||
):
|
):
|
||||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||||
and beam-search.
|
and beam-search.
|
||||||
@@ -739,6 +740,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
num_return_sequences = (
|
num_return_sequences = (
|
||||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||||
)
|
)
|
||||||
|
# TODO: think about how to make this cleaner
|
||||||
|
decoder_start_token_id = (
|
||||||
|
decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
batch_size = input_ids.shape[0] # overridden by the input batch_size
|
batch_size = input_ids.shape[0] # overridden by the input batch_size
|
||||||
@@ -765,6 +770,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
assert (eos_token_ids is None) or (
|
assert (eos_token_ids is None) or (
|
||||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
||||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||||
|
assert (
|
||||||
|
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||||
|
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||||
assert (
|
assert (
|
||||||
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
||||||
@@ -845,7 +853,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
encoder_inputs = input_ids
|
encoder_inputs = input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
bos_token_id, # TODO: wait for results of Bart CNN summarization
|
decoder_start_token_id, # TODO: see whether this is the best result
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -432,7 +432,11 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
||||||
extra_len = 20
|
extra_len = 20
|
||||||
gen_tokens = hf.generate(
|
gen_tokens = hf.generate(
|
||||||
tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
|
tokens,
|
||||||
|
num_beams=4,
|
||||||
|
max_length=extra_len + 2,
|
||||||
|
do_sample=False,
|
||||||
|
decoder_start_token_id=hf.config.eos_token_id,
|
||||||
) # repetition_penalty=10.,
|
) # repetition_penalty=10.,
|
||||||
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
|
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
|
||||||
generated = [tok.decode(g,) for g in gen_tokens]
|
generated = [tok.decode(g,) for g in gen_tokens]
|
||||||
@@ -477,6 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size=3,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
|
decoder_start_token_id=hf.config.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
decoded = [
|
decoded = [
|
||||||
|
|||||||
Reference in New Issue
Block a user