Clean special token init in modeling_....py (#3264)

* make style

* fix conflicts
This commit is contained in:
Patrick von Platen
2020-03-20 21:41:04 +01:00
committed by GitHub
parent 8becb73293
commit 95e00d0808
22 changed files with 117 additions and 115 deletions

View File

@@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
def prepare_scores_for_generation(self, scores, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores
@staticmethod
@@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel):
encoder_outputs=encoder_outputs,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_ids[0])
eos_mask = input_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]