Generate: add Bloom fixes for contrastive search (#20213)

This commit is contained in:
Joao Gante
2022-11-14 18:34:11 +00:00
committed by GitHub
parent fda125638f
commit 938cb04789
3 changed files with 72 additions and 27 deletions

View File

@@ -1411,9 +1411,8 @@ class GenerationTesterMixin:
# check `generate()` and `contrastive_search()` are equal
for model_class in self.all_generative_model_classes:
# TODO: Fix Bloom. Bloom fails because `past` has a different shape.
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
return
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
@@ -1434,9 +1433,8 @@ class GenerationTesterMixin:
def test_contrastive_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
# TODO: Fix Bloom. Bloom fails because `past` has a different shape.
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
return
# enable cache