VLMs: enable generation tests - last batch (#34484)

* add tests for 3 more vlms

* fix fuyu back

* skip test
This commit is contained in:
Raushan Turganbay
2024-11-21 11:00:22 +01:00
committed by GitHub
parent 40821a2478
commit 28fb02fc05
6 changed files with 129 additions and 9 deletions

View File

@@ -17,12 +17,15 @@
import io
import unittest
import pytest
import requests
from parameterized import parameterized
from transformers import FuyuConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -263,8 +266,9 @@ class FuyuModelTester:
@require_torch
class FuyuModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
all_generative_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {}
)
@@ -296,6 +300,16 @@ class FuyuModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_matches_greedy_search(self):
pass
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_sample(self):
pass
# TODO: Fix me (once this model gets more usage)
@unittest.skip(reason="Does not work on the tiny model.")
def test_disk_offload_bin(self):