Fix CIs for PyTorch 1.13 (#20686)

* fix 1

* fix 2

* fix 3

* fix 4

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-12-08 18:51:54 +01:00
committed by GitHub
parent bcc069ddb8
commit e3cc4487fe
15 changed files with 18 additions and 15 deletions

View File

@@ -2982,7 +2982,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
)
hidden_states = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")