From d880a5fbde719775455ebd21884e86370a99fb4f Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Sat, 7 Mar 2020 10:55:23 +0100 Subject: [PATCH] finalized PR --- src/transformers/modeling_utils.py | 16 +++++++++------- tests/test_modeling_bart.py | 8 ++------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7da96897f8..a1a0306684 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -798,7 +798,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" # create attention mask if necessary - # TODO (PVP): this should later be handled by the forward fn() in each model + # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): attention_mask = input_ids.ne(pad_token_id).long() elif attention_mask is None: @@ -989,10 +989,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if unfinished_sents.max() == 0: break - # extend attention_mask for new generated input + # extend attention_mask for new generated input if only decoder if self.config.is_encoder_decoder is False: attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1 + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) cur_len = cur_len + 1 @@ -1078,7 +1078,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): next_token_logits = next_token_logits / temperature scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) - if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here? + if ( + self.config.is_encoder_decoder + ): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? scores = self.prepare_scores_for_generation(scores, cur_len, max_length) # set eos token prob to zero if min_length is not reached @@ -1205,10 +1207,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if past: past = self._reorder_cache(past, beam_idx) - # extend attention_mask for new generated input + # extend attention_mask for new generated input if only decoder if self.config.is_encoder_decoder is False: attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1 + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) # update current length @@ -1270,7 +1272,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) if self.config.is_encoder_decoder: - # do not return first token + # do not return first token return decoded[:, 1:] return decoded diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index af3d1567d0..5c361d0499 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -453,9 +453,7 @@ class BartModelIntegrationTest(unittest.TestCase): EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway." dct = tok.batch_encode_plus( - # [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY], - [IRAN_ARTICLE, ARTICLE_SUBWAY], - # [FRANCE_ARTICLE, SHORTER_ARTICLE], + [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY], max_length=1024, pad_to_max_length=True, return_tensors="pt", @@ -482,9 +480,7 @@ class BartModelIntegrationTest(unittest.TestCase): ] self.assertListEqual( - # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], - [EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], - # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER], + [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], decoded, ) # TODO(SS): run fairseq again with num_beams=2, min_len=20.