finalize generation merge
This commit is contained in:
@@ -962,8 +962,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
if cur_len == 1:
|
||||
self._force_token_ids_generation(scores, self.config.bos_token_id)
|
||||
if cur_len == max_length - 1:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_ids)
|
||||
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
|
||||
return scores
|
||||
|
||||
@staticmethod
|
||||
@@ -1056,7 +1056,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
encoder_outputs=encoder_outputs,
|
||||
)
|
||||
x = outputs[0] # last hidden state
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||
eos_mask = input_ids.eq(self.config.eos_token_ids[0])
|
||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
|
||||
|
||||
Reference in New Issue
Block a user