Fix CIs for PyTorch 1.13 (#20686)
* fix 1
* fix 2
* fix 3
* fix 4

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1538,7 +1538,7 @@ class BartForSequenceClassification(BartPretrainedModel):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0] # last hidden state
|
hidden_states = outputs[0] # last hidden state
|
||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
|
||||||
|
|
||||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|||||||
@@ -2738,7 +2738,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0] # last hidden state
|
hidden_states = outputs[0] # last hidden state
|
||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
|
||||||
|
|
||||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|||||||
@@ -1057,7 +1057,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
|
|||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
else:
|
else:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
|
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -734,7 +734,8 @@ class CLIPTextTransformer(nn.Module):
|
|||||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
pooled_output = last_hidden_state[
|
pooled_output = last_hidden_state[
|
||||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
|
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||||
|
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
||||||
]
|
]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
|
|||||||
@@ -746,7 +746,8 @@ class CLIPSegTextTransformer(nn.Module):
|
|||||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
pooled_output = last_hidden_state[
|
pooled_output = last_hidden_state[
|
||||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
|
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||||
|
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
||||||
]
|
]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
|
|||||||
@@ -1401,7 +1401,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
|||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
else:
|
else:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -883,7 +883,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
|||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
else:
|
else:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -969,7 +969,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
|||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
else:
|
else:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -1134,7 +1134,8 @@ class GroupViTTextTransformer(nn.Module):
|
|||||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
pooled_output = last_hidden_state[
|
pooled_output = last_hidden_state[
|
||||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
|
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||||
|
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
||||||
]
|
]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
|
|||||||
@@ -2608,7 +2608,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0] # last hidden state
|
hidden_states = outputs[0] # last hidden state
|
||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
|
||||||
|
|
||||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|||||||
@@ -1525,7 +1525,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0] # last hidden state
|
hidden_states = outputs[0] # last hidden state
|
||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
|
||||||
|
|
||||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|||||||
@@ -1674,7 +1674,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0] # last hidden state
|
hidden_states = outputs[0] # last hidden state
|
||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
|
||||||
|
|
||||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|||||||
@@ -1069,7 +1069,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
|
|||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
else:
|
else:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -1496,7 +1496,7 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0] # last hidden state
|
hidden_states = outputs[0] # last hidden state
|
||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
|
||||||
|
|
||||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|||||||
@@ -2982,7 +2982,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0] # last hidden state
|
hidden_states = outputs[0] # last hidden state
|
||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
|
||||||
|
|
||||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|||||||
Reference in New Issue
Block a user