Generate: remove most decoder-only LLMs prepare_inputs_for_generation (#33870)
This commit is contained in:
@@ -3000,7 +3000,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
|
||||
continue
|
||||
model = model_class(config)
|
||||
@@ -3047,7 +3047,10 @@ class ModelTesterMixin:
|
||||
**inputs,
|
||||
max_new_tokens=2,
|
||||
)
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
# NOTE: this test changes the order of FP ops, there may be tiny differences in the output
|
||||
number_of_different_tokens = (out_ids != out_embeds).sum()
|
||||
max_differences = int(out_ids.shape[0] * out_ids.shape[1] * 0.1)
|
||||
self.assertTrue(number_of_different_tokens <= max_differences) # accept up to 10% mismatch
|
||||
|
||||
@require_non_xpu
|
||||
@require_torch_multi_gpu
|
||||
|
||||
Reference in New Issue
Block a user