Generate: add Bloom fixes for contrastive search (#20213)
This commit is contained in:
@@ -1411,9 +1411,8 @@ class GenerationTesterMixin:
|
||||
# check `generate()` and `contrastive_search()` are equal
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
# TODO: Fix Bloom. Bloom fails because `past` has a different shape.
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
return
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
@@ -1434,9 +1433,8 @@ class GenerationTesterMixin:
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
# TODO: Fix Bloom. Bloom fails because `past` has a different shape.
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
return
|
||||
|
||||
# enable cache
|
||||
|
||||
Reference in New Issue
Block a user