From 7cba11fb9b5a86769f6613614d32f8905455eb64 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 5 Mar 2020 15:48:00 +0100 Subject: [PATCH] better naming --- src/transformers/modeling_bart.py | 6 +++--- src/transformers/modeling_utils.py | 5 ++--- tests/test_modeling_bart.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index dff240c809..c183ede558 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -943,7 +943,7 @@ class BartForConditionalGeneration(PretrainedBartModel): return outputs @staticmethod - def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask): + def prepare_inputs_for_generation_bart(input_ids, past, decoder_input_ids, attention_mask): if past is None: # first step encoder_outputs, decoder_cached_states = None, None else: @@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel): return self.lm_head @torch.no_grad() - def generate_1( + def generate_bart( self, input_ids, attention_mask=None, @@ -1113,7 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel): self.model.decoder.generation_mode = True # tells decoder not to use causal mask for step in range(max_length + 1): decoder_input_ids = prev_output_tokens.clone() - model_inputs = self.prepare_inputs_for_generation_1( + model_inputs = self.prepare_inputs_for_generation_bart( input_ids, decoder_cache, decoder_input_ids, attention_mask, ) outputs = self(**model_inputs) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 51f3b15f68..ad6c29b8bf 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -411,7 +411,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): else: raise EnvironmentError( "Error no file named {} found in directory {} or `from_tf` set to False".format( - [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index",], + [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], pretrained_model_name_or_path, ) ) @@ -816,7 +816,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): effective_batch_size * num_beams, input_ids_len ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) - # TODO (PVP): check eos_token_id # TODO (PVP): probably not the best way to check whether model is encoder decoder is_encoder_decoder = ( hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder") @@ -829,7 +828,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): encoder_inputs = input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), - # eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no? + # eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case bos_token_id, dtype=torch.long, device=next(self.parameters()).device, diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 34298d8134..c248c1e73d 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -427,7 +427,7 @@ class BartModelIntegrationTest(unittest.TestCase): text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian" tokens = tok.encode(text, return_tensors="pt").to(torch_device) extra_len = 20 - gen_tokens_bart = hf.generate_1(tokens, num_beams=4, max_length=extra_len,) # repetition_penalty=10., + gen_tokens_bart = hf.generate_bart(tokens, num_beams=3, max_length=extra_len,) # repetition_penalty=10., gen_tokens = hf.generate( tokens, num_beams=4, max_length=extra_len + 2, do_sample=False ) # repetition_penalty=10.,