[CI] green llama tests (#37244)
* green llama tests
* use cleanup instead
* better test comment; cleanup upgrade
* better test comment; cleanup upgrade
This commit is contained in:
@@ -2285,6 +2285,14 @@ class GenerationTesterMixin:
|
||||
inputs_dict[input_name] = input_data
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
# FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
|
||||
# attention masks at test time and, with generate, the mask will be appended with 1s on the right,
|
||||
# resulting in a mask with holes (not supported properly by FA2).
|
||||
if attn_implementation == "flash_attention_2":
|
||||
for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
|
||||
if input_name in inputs_dict:
|
||||
inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
|
||||
@@ -2339,8 +2347,6 @@ class GenerationTesterMixin:
|
||||
@slow
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
|
||||
# TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
|
||||
# check whether we still need the overwrites
|
||||
self._test_attention_implementation("flash_attention_2")
|
||||
|
||||
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
@@ -3974,7 +3980,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# TODO: We need to raise a warning in case the cache is not set correctly
|
||||
# with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
|
||||
# past_key_values = StaticCache(
|
||||
# config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
|
||||
# config=model.config, max_batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
|
||||
# )
|
||||
# results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
|
||||
|
||||
@@ -3982,7 +3988,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
layer_device_map = {0: 0, 1: 1}
|
||||
past_key_values = StaticCache(
|
||||
config=model.config,
|
||||
batch_size=1,
|
||||
max_batch_size=1,
|
||||
max_cache_len=30,
|
||||
device=torch_device,
|
||||
dtype=model.dtype,
|
||||
@@ -4183,7 +4189,11 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
batch_size = 2
|
||||
query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
|
||||
static_cache = StaticCache(
|
||||
config=config, batch_size=batch_size, max_cache_len=max_cache_len, device=torch_device, dtype=torch.float32
|
||||
config=config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=max_cache_len,
|
||||
device=torch_device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
static_cache = model(init_input_ids, past_key_values=static_cache).past_key_values
|
||||
model_inputs = model.prepare_inputs_for_generation(
|
||||
|
||||
Reference in New Issue
Block a user