From 2acfe639640782b42608b7a26808b7bf9b03438a Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 6 Mar 2020 22:19:01 +0100 Subject: [PATCH] best current version and make style --- src/transformers/configuration_t5.py | 3 +- src/transformers/modeling_bart.py | 6 ++- src/transformers/modeling_utils.py | 56 +++++++++++++++------------- tests/test_modeling_bart.py | 16 ++++---- 4 files changed, 45 insertions(+), 36 deletions(-) diff --git a/src/transformers/configuration_t5.py b/src/transformers/configuration_t5.py index 28772f85b7..767bec762d 100644 --- a/src/transformers/configuration_t5.py +++ b/src/transformers/configuration_t5.py @@ -79,8 +79,7 @@ class T5Config(PretrainedConfig): **kwargs ): super().__init__( - is_encoder_decoder=is_encoder_decoder, - **kwargs, + is_encoder_decoder=is_encoder_decoder, **kwargs, ) self.vocab_size = vocab_size self.n_positions = n_positions diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 44d679a312..bb32b1c96c 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -942,7 +942,9 @@ class BartForConditionalGeneration(PretrainedBartModel): return outputs def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask): - assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape) + assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format( + attention_mask.shape, encoder_inputs.shape + ) if past is None: # first step encoder_outputs, decoder_cached_states = None, None else: @@ -954,7 +956,7 @@ class BartForConditionalGeneration(PretrainedBartModel): "encoder_outputs": encoder_outputs, "decoder_cached_states": decoder_cached_states, "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask + "attention_mask": attention_mask, } def prepare_scores_for_generation(self, scores, cur_len, max_length): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b1a6b8f366..7da96897f8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -19,8 +19,6 @@ import logging import os import typing -import ipdb - import torch from torch import nn from torch.nn import CrossEntropyLoss @@ -829,7 +827,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if num_return_sequences > 1 or num_beams > 1: input_ids_len = input_ids.shape[-1] input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) - attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) + attention_mask = attention_mask.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, input_ids_len + ) input_ids = input_ids.contiguous().view( effective_batch_size * num_beams, input_ids_len @@ -846,9 +846,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): encoder_inputs = input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), - eos_token_id, -# bos_token_id, - # eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case + eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this dtype=torch.long, device=next(self.parameters()).device, ) @@ -919,7 +917,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): eos_token_ids, batch_size, encoder_inputs, - attention_mask + attention_mask, ): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. @@ -930,7 +928,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): past = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask) + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask + ) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] @@ -948,16 +948,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) for batch_idx in range(batch_size): - next_token_logits[ - batch_idx, banned_tokens[batch_idx] - ] = -float('inf') + next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") # set eos token prob to zero if min_length is not reached if eos_token_ids is not None and cur_len < min_length: for eos_token_id in eos_token_ids: - next_token_logits[ - :, eos_token_id - ] = -float('inf') + next_token_logits[:, eos_token_id] = -float("inf") if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) @@ -995,7 +991,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # extend attention_mask for new generated input if self.config.is_encoder_decoder is False: - attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1) + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1 + ) cur_len = cur_len + 1 @@ -1041,8 +1039,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # generated hypotheses generated_hyps = [ - BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size) -# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size) + BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) + for _ in range(batch_size) ] # scores for each sentence in the beam @@ -1060,7 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask) + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask + ) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) @@ -1084,17 +1084,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # set eos token prob to zero if min_length is not reached if eos_token_ids is not None and cur_len < min_length: for eos_token_id in eos_token_ids: - scores[:, eos_token_id] = -float('inf') + scores[:, eos_token_id] = -float("inf") if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams num_batch_hypotheses = batch_size * num_beams # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len) + banned_batch_tokens = calc_banned_tokens( + input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len + ) for i, banned_tokens in enumerate(banned_batch_tokens): - scores[i, banned_tokens] = -float('inf') + scores[i, banned_tokens] = -float("inf") - assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size)) + assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( + scores.shape, (batch_size * num_beams, vocab_size) + ) if do_sample: _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) @@ -1203,7 +1207,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # extend attention_mask for new generated input if self.config.is_encoder_decoder is False: - attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1) + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1 + ) # update current length cur_len = cur_len + 1 @@ -1278,7 +1284,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): device=next(self.parameters()).device, ) assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]" - scores[:, all_but_token_ids_mask] = -float('inf') + scores[:, all_but_token_ids_mask] = -float("inf") @staticmethod def _reorder_cache(past, beam_idx): @@ -1311,7 +1317,7 @@ def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len) def _get_generated_ngrams(hypo_idx): # Before decoding the next token, prevent decoding of ngrams that have already appeared start_idx = cur_len + 1 - no_repeat_ngram_size - ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx: cur_len].tolist()) + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) return generated_ngrams[hypo_idx].get(ngram_idx, []) banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index b244fa1f96..af3d1567d0 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -104,7 +104,9 @@ def prepare_bart_inputs_dict( @require_torch class BARTModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else () + ) all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True # TODO(SS): fix the below in a separate PR @@ -451,9 +453,9 @@ class BartModelIntegrationTest(unittest.TestCase): EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway." dct = tok.batch_encode_plus( -# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY], + # [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY], [IRAN_ARTICLE, ARTICLE_SUBWAY], -# [FRANCE_ARTICLE, SHORTER_ARTICLE], + # [FRANCE_ARTICLE, SHORTER_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt", @@ -472,17 +474,17 @@ class BartModelIntegrationTest(unittest.TestCase): min_length=min_length + 1, no_repeat_ngram_size=3, do_sample=False, - early_stopping=True + early_stopping=True, ) - + decoded = [ tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch ] self.assertListEqual( -# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], + # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], [EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], -# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER], + # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER], decoded, ) # TODO(SS): run fairseq again with num_beams=2, min_len=20.