Use torch.unique_consecutive to check same element (#13637)
We use `torch.unique` here only to check whether every elements have the same value. Therefore, we can use `torch.unique_consecutive` here. This function eliminates all but the first element from every consecutive group of equivalent elements. Like, if we apply this function to `[1, 2, 2, 1]`, it will result in `[1, 2, 1]`. As you could see, this is enough for checking whether every elements have the same value. Since `torch.unique_consecutive` do less thing, it is much more faster. On my computer, it is 25x faster on GPU and 15x faster on CPU.
This commit is contained in:
@@ -1457,7 +1457,7 @@ class BartForSequenceClassification(BartPretrainedModel):
|
|||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||||
|
|
||||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||||
:, -1, :
|
:, -1, :
|
||||||
|
|||||||
@@ -2668,7 +2668,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
|||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||||
|
|
||||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||||
:, -1, :
|
:, -1, :
|
||||||
|
|||||||
@@ -2522,7 +2522,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||||
|
|
||||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||||
:, -1, :
|
:, -1, :
|
||||||
|
|||||||
@@ -1463,7 +1463,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
|||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||||
|
|
||||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||||
:, -1, :
|
:, -1, :
|
||||||
|
|||||||
@@ -2972,7 +2972,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
|
|||||||
|
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||||
|
|
||||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||||
:, -1, :
|
:, -1, :
|
||||||
|
|||||||
Reference in New Issue
Block a user