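Better naming: rename generate_1 to generate_bart and prepare_inputs_for_generation_1 to prepare_inputs_for_generation_bart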
This commit is contained in:
@@ -943,7 +943,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask):
|
def prepare_inputs_for_generation_bart(input_ids, past, decoder_input_ids, attention_mask):
|
||||||
if past is None: # first step
|
if past is None: # first step
|
||||||
encoder_outputs, decoder_cached_states = None, None
|
encoder_outputs, decoder_cached_states = None, None
|
||||||
else:
|
else:
|
||||||
@@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate_1(
|
def generate_bart(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
@@ -1113,7 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
|
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
|
||||||
for step in range(max_length + 1):
|
for step in range(max_length + 1):
|
||||||
decoder_input_ids = prev_output_tokens.clone()
|
decoder_input_ids = prev_output_tokens.clone()
|
||||||
model_inputs = self.prepare_inputs_for_generation_1(
|
model_inputs = self.prepare_inputs_for_generation_bart(
|
||||||
input_ids, decoder_cache, decoder_input_ids, attention_mask,
|
input_ids, decoder_cache, decoder_input_ids, attention_mask,
|
||||||
)
|
)
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
|
|||||||
@@ -411,7 +411,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
"Error no file named {} found in directory {} or `from_tf` set to False".format(
|
"Error no file named {} found in directory {} or `from_tf` set to False".format(
|
||||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index",],
|
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -816,7 +816,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
effective_batch_size * num_beams, input_ids_len
|
effective_batch_size * num_beams, input_ids_len
|
||||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||||
|
|
||||||
# TODO (PVP): check eos_token_id
|
|
||||||
# TODO (PVP): probably not the best way to check whether model is encoder decoder
|
# TODO (PVP): probably not the best way to check whether model is encoder decoder
|
||||||
is_encoder_decoder = (
|
is_encoder_decoder = (
|
||||||
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
|
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
|
||||||
@@ -829,7 +828,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
encoder_inputs = input_ids
|
encoder_inputs = input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
# eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no?
|
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well; TODO: check whether it also works on the hard summarization case
|
||||||
bos_token_id,
|
bos_token_id,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
|
|||||||
@@ -427,7 +427,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
|
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
|
||||||
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
||||||
extra_len = 20
|
extra_len = 20
|
||||||
gen_tokens_bart = hf.generate_1(tokens, num_beams=4, max_length=extra_len,) # repetition_penalty=10.,
|
gen_tokens_bart = hf.generate_bart(tokens, num_beams=3, max_length=extra_len,) # repetition_penalty=10.,
|
||||||
gen_tokens = hf.generate(
|
gen_tokens = hf.generate(
|
||||||
tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
|
tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
|
||||||
) # repetition_penalty=10.,
|
) # repetition_penalty=10.,
|
||||||
|
|||||||
Reference in New Issue
Block a user