Merge pull request #3225 from patrickvonplaten/finalize_merge_bart_generate_into_default_generate

Complete merge Seq-2-Seq generation into default generation
This commit is contained in:
Thomas Wolf
2020-03-14 15:08:59 +01:00
committed by GitHub
3 changed files with 25 additions and 8 deletions

View File

@@ -61,7 +61,7 @@ class ModelTester:
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 20
-self.eos_token_id = 2
+self.eos_token_ids = [2]
self.pad_token_id = 1
self.bos_token_id = 0
torch.manual_seed(0)
@@ -82,7 +82,7 @@ class ModelTester:
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
-eos_token_ids=[self.eos_token_id],
+eos_token_ids=[2],
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
@@ -438,7 +438,11 @@ class BartModelIntegrationTest(unittest.TestCase):
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
extra_len = 20
gen_tokens = hf.generate(
-tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
+tokens,
+num_beams=4,
+max_length=extra_len + 2,
+do_sample=False,
+decoder_start_token_id=hf.config.eos_token_ids[0],
) # repetition_penalty=10.,
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
generated = [tok.decode(g,) for g in gen_tokens]
@@ -483,6 +487,7 @@ class BartModelIntegrationTest(unittest.TestCase):
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True,
+decoder_start_token_id=hf.config.eos_token_ids[0],
)
decoded = [