Generate: add test for left-padding support (#22322)
This commit is contained in:
@@ -1497,6 +1497,54 @@ class GenerationTesterMixin:
|
|||||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||||
|
|
||||||
|
# TODO (joao): this test is actually not slow :) However, it is not passing in some models (e.g. GPTNeoX) and the
# fix for some models is quite lengthy. Being slow means it doesn't block our push CI while we fix it.
@slow
def test_left_padding_compatibility(self):
    """Check that left-padding the input leaves the next-token logits of decoder-only models unchanged.

    The check done in this test is fairly difficult -- depending on the model architecture, passing the right
    position index for the position embeddings can still result in a different output, due to numerical masking.
    On the other hand, for some types of position embeddings, an incorrect position index can have a minimal
    impact on the output.

    There are two tricks employed to check whether left-padding compatibility is in place:
    1 - To reduce the negative impact of the numerical attention mask on a correct position index, we set the
        padding size to 1.
    2 - To reduce the chance of false positives (i.e. passing when it should be failing), we run the check
        multiple times with random inputs, and it has to pass with all of them.

    NOTE: even with 2), there is still some chance of false positives in this test.
    """

    def _next_token_logits(model, signature, input_ids, attention_mask):
        # Run a forward pass and return the logits of the last position. `position_ids` are only passed when
        # the model's `forward` accepts them; they are derived from the attention mask so that real tokens are
        # numbered from 0 regardless of how much left-padding precedes them.
        model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if "position_ids" in signature:
            position_ids = torch.cumsum(attention_mask, dim=-1) - 1
            # Masked (padding) positions get an arbitrary valid index -- they are ignored by the attention.
            position_ids.masked_fill_(attention_mask == 0, 1)
            model_kwargs["position_ids"] = position_ids
        # No gradients are needed for a pure output comparison.
        with torch.no_grad():
            return model(**model_kwargs).logits[:, -1, :]

    for model_class in self.all_generative_model_classes:
        config, _, _, _ = self._get_input_ids_and_config()
        if config.is_encoder_decoder:
            continue  # skip for encoder-decoder models -- they don't need left-padding compatibility
        model = model_class(config).to(torch_device).eval()
        signature = inspect.signature(model.forward).parameters.keys()

        # Some configs do not define a pad token. The padded position is masked out, so any valid token id
        # can stand in for it; fall back to 0 instead of failing on `tensor * None`.
        pad_token_id = config.pad_token_id if config.pad_token_id is not None else 0

        for _ in range(10):  # there may be false positives with 10 runs, we rely on the CI to catch the flakiness
            _, input_ids, attention_mask, _ = self._get_input_ids_and_config()
            next_logits_wo_padding = _next_token_logits(model, signature, input_ids, attention_mask)

            # Prepend a single padding token (and a matching zero in the attention mask) to every row.
            pad_shape = (input_ids.shape[0], 1)
            padding = torch.full(pad_shape, pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
            padded_input_ids = torch.cat((padding, input_ids), dim=1)
            padded_attention_mask = torch.cat((attention_mask.new_zeros(pad_shape), attention_mask), dim=1)
            next_logits_with_padding = _next_token_logits(
                model, signature, padded_input_ids, padded_attention_mask
            )

            # Fail on the first mismatch, naming the offending model class to ease debugging.
            self.assertTrue(
                torch.allclose(next_logits_wo_padding, next_logits_with_padding),
                msg=f"{model_class.__name__}: next-token logits differ between unpadded and left-padded inputs",
            )
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
num_sequences_in_output = batch_size * num_return_sequences
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import unittest
|
|||||||
from transformers import GPTNeoXConfig, is_torch_available
|
from transformers import GPTNeoXConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, torch_device
|
from transformers.testing_utils import require_torch, torch_device
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
@@ -186,7 +187,7 @@ class GPTNeoXModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GPTNeoXModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
|
all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
|
|||||||
Reference in New Issue
Block a user