finalize generation merge

This commit is contained in:
Patrick von Platen
2020-03-11 11:53:36 +01:00
parent 1ba21f96ca
commit a332cc9f7f
4 changed files with 10 additions and 13 deletions

View File

@@ -962,8 +962,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
def prepare_scores_for_generation(self, scores, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
if cur_len == max_length - 1:
self._force_token_ids_generation(scores, self.config.eos_token_ids)
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
return scores
@staticmethod
@@ -1056,7 +1056,7 @@ class BartForSequenceClassification(PretrainedBartModel):
encoder_outputs=encoder_outputs,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
eos_mask = input_ids.eq(self.config.eos_token_ids[0])
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]