diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index d877034730..d62740274c 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1538,7 +1538,7 @@ class BartForSequenceClassification(BartPretrainedModel): ) hidden_states = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ed7539b7e2..9b3b41ec4e 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2738,7 +2738,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel): ) hidden_states = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 6002cc7be7..4c66d47b89 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -1057,7 +1057,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel): sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1 + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 logger.warning( diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index c2529c4592..b4b7e7b453 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -734,7 +734,8 @@ class CLIPTextTransformer(nn.Module): # take features from the eot embedding (eot_token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1) + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), ] if not return_dict: diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index c4dd399183..dbefa1edde 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -746,7 +746,8 @@ class CLIPSegTextTransformer(nn.Module): # take features from the eot embedding (eot_token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1) + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), ] if not return_dict: diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index f8b8290a20..6b80e565e0 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1401,7 +1401,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel): sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 logger.warning( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 0d03d227e2..a39a1b3d1a 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -883,7 +883,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel): sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 logger.warning( diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index c20ebcb77d..32a714684a 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -969,7 +969,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel): sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 logger.warning( diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 98a8d9be58..e027c6b90c 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -1134,7 +1134,8 @@ class GroupViTTextTransformer(nn.Module): # take features from the eot embedding (eot_token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1) + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), ] if not return_dict: diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 388e94641d..0e7ae9d402 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2608,7 +2608,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel): ) hidden_states = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 185ccb277f..803a532bd9 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1525,7 +1525,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel): ) hidden_states = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 08da096992..46315b0965 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1674,7 +1674,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel): ) hidden_states = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index b37e0cca39..ec090c932e 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -1069,7 +1069,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel): sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 logger.warning( diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 7f98254756..79cf5c855c 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1496,7 +1496,7 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel): ) hidden_states = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 9e2154901a..a7909fad15 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -2982,7 +2982,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt ) hidden_states = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_id) + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.")