[tests] remove tf/flax tests in /generation (#36235)

This commit is contained in:
Joao Gante
2025-02-17 14:59:22 +00:00
committed by GitHub
parent c877c9fa5b
commit 55493f1390
26 changed files with 428 additions and 2663 deletions

View File

@@ -19,7 +19,6 @@ import timeout_decorator # noqa
from transformers import BartConfig, BartTokenizer, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -324,7 +323,7 @@ class BartHeadTests(unittest.TestCase):
@require_flax
class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (
(

View File

@@ -20,7 +20,6 @@ import timeout_decorator # noqa
from transformers import BlenderbotConfig, is_flax_available
from transformers.testing_utils import jax_device, require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -309,7 +308,7 @@ class BlenderbotHeadTests(unittest.TestCase):
@require_flax
class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (
(

View File

@@ -20,7 +20,6 @@ import timeout_decorator # noqa
from transformers import BlenderbotSmallConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -308,7 +307,7 @@ class BlenderbotHeadTests(unittest.TestCase):
@require_flax
class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (
(

View File

@@ -18,7 +18,6 @@ import numpy as np
from transformers import BloomConfig, BloomTokenizerFast, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -169,7 +168,7 @@ class FlaxBloomModelTester:
@require_flax
class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxBloomModel, FlaxBloomForCausalLM) if is_flax_available() else ()
def setUp(self):

View File

@@ -18,7 +18,6 @@ import numpy as np
from transformers import AutoTokenizer, GemmaConfig, is_flax_available
from transformers.testing_utils import require_flax, require_read_token, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -174,7 +173,7 @@ class FlaxGemmaModelTester:
@require_flax
class FlaxGemmaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxGemmaModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGemmaModel, FlaxGemmaForCausalLM) if is_flax_available() else ()
def setUp(self):

View File

@@ -22,7 +22,6 @@ import transformers
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -209,7 +208,7 @@ class FlaxGPT2ModelTester:
@require_flax
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
def setUp(self):

View File

@@ -22,7 +22,6 @@ import transformers
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@@ -181,7 +180,7 @@ class FlaxGPTNeoModelTester:
@require_flax
class FlaxGPTNeoModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxGPTNeoModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGPTNeoModel, FlaxGPTNeoForCausalLM) if is_flax_available() else ()
def setUp(self):

View File

@@ -22,7 +22,6 @@ import transformers
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@@ -178,7 +177,7 @@ class FlaxGPTJModelTester:
@require_flax
class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxGPTJModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGPTJModel, FlaxGPTJForCausalLM) if is_flax_available() else ()
def setUp(self):

View File

@@ -20,7 +20,6 @@ import numpy as np
from transformers import LlamaConfig, is_flax_available, is_tokenizers_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -174,7 +173,7 @@ class FlaxLlamaModelTester:
@require_flax
class FlaxLlamaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxLlamaModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxLlamaModel, FlaxLlamaForCausalLM) if is_flax_available() else ()
def setUp(self):

View File

@@ -28,7 +28,6 @@ from transformers.testing_utils import (
slow,
)
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -235,7 +234,7 @@ class FlaxLongT5ModelTester:
@require_flax
class FlaxLongT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxLongT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxLongT5Model, FlaxLongT5ForConditionalGeneration) if is_flax_available() else ()
is_encoder_decoder = True

View File

@@ -21,7 +21,6 @@ from transformers import MarianConfig, is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -228,7 +227,7 @@ class FlaxMarianModelTester:
@require_flax
class FlaxMarianModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
class FlaxMarianModelTest(FlaxModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (FlaxMarianModel, FlaxMarianMTModel) if is_flax_available() else ()

View File

@@ -21,7 +21,6 @@ from transformers import MBartConfig, is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -330,7 +329,7 @@ class MBartHeadTests(unittest.TestCase):
@require_flax
class FlaxMBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
class FlaxMBartModelTest(FlaxModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (
(

View File

@@ -20,7 +20,6 @@ import numpy as np
from transformers import MistralConfig, is_flax_available, is_tokenizers_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -185,7 +184,7 @@ class FlaxMistralModelTester:
@require_flax
class FlaxMistralModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxMistralModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxMistralModel, FlaxMistralForCausalLM) if is_flax_available() else ()
def setUp(self):

View File

@@ -24,7 +24,6 @@ from transformers.testing_utils import (
slow,
)
from ...generation.test_tf_utils import TFGenerationIntegrationTests
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -244,7 +243,7 @@ class TFMistralModelTester:
@require_tf
class TFMistralModelTest(TFModelTesterMixin, TFGenerationIntegrationTests, PipelineTesterMixin, unittest.TestCase):
class TFMistralModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(TFMistralModel, TFMistralForCausalLM, TFMistralForSequenceClassification) if is_tf_available() else ()
)

View File

@@ -19,7 +19,6 @@ import timeout_decorator # noqa
from transformers import OPTConfig, is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -203,7 +202,7 @@ class FlaxOPTModelTester:
@require_flax
class FlaxOPTModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
class FlaxOPTModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxOPTModel, FlaxOPTForCausalLM) if is_flax_available() else ()
def setUp(self):

View File

@@ -416,82 +416,6 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
def test_generate_without_input_ids(self):
pass
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
# `input_features`
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_features = inputs_dict.get("input_features", None)
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config)
if config.bos_token_id is None:
# if bos token id is not defined model needs input_features
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
self._check_generated_ids(model.generate(input_features, do_sample=True))
with self.assertRaises(ValueError):
# generating multiple sequences when no beam search generation
# is not allowed as it would always generate the same sequences
model.generate(input_features, do_sample=False, num_return_sequences=2)
# num_return_sequences > 1, sample
self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
output_tokens = model.generate(
input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_features.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
# `input_features`
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_features = inputs_dict.get("input_features", None)
for model_class in self.all_generative_model_classes:
model = model_class(config)
if config.bos_token_id is None:
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
with self.assertRaises(ValueError):
# generating more sequences than having beams leads is not possible
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
# num_return_sequences > 1, sample
self._check_generated_ids(
model.generate(
input_features,
do_sample=True,
num_beams=2,
num_return_sequences=2,
)
)
# num_return_sequences > 1, greedy
self._check_generated_ids(
model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
)
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
output_tokens = model.generate(
input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_features.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
# overwritten from parent -- the input is `input_features`, not `input_ids`
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -27,7 +27,6 @@ from transformers.testing_utils import (
slow,
)
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -227,7 +226,7 @@ class FlaxT5ModelTester:
@require_flax
class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxT5Model, FlaxT5ForConditionalGeneration) if is_flax_available() else ()
is_encoder_decoder = True

View File

@@ -524,127 +524,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
)
def test_generate_without_input_ids(self):
pass
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
# `input_features`
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_features = inputs_dict.get("input_features", None)
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config)
if config.bos_token_id is None:
# if bos token id is not defined model needs input_features
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
self._check_generated_ids(model.generate(input_features, do_sample=True))
with self.assertRaises(ValueError):
# generating multiple sequences when no beam search generation
# is not allowed as it would always generate the same sequences
model.generate(input_features, do_sample=False, num_return_sequences=2)
# num_return_sequences > 1, sample
self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
output_tokens = model.generate(
input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_features.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
# `input_features`
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_features = inputs_dict.get("input_features", None)
for model_class in self.all_generative_model_classes:
model = model_class(config)
if config.bos_token_id is None:
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
with self.assertRaises(ValueError):
# generating more sequences than having beams leads is not possible
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
# num_return_sequences > 1, sample
self._check_generated_ids(
model.generate(
input_features,
do_sample=True,
num_beams=2,
num_return_sequences=2,
)
)
# num_return_sequences > 1, greedy
self._check_generated_ids(
model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
)
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
output_tokens = model.generate(
input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_features.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
def test_generate_with_prompt_ids_and_task_and_language(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = TFWhisperForConditionalGeneration(config)
input_features = input_dict["input_features"]
prompt_ids = np.arange(5)
language = "<|de|>"
task = "translate"
lang_id = 6
task_id = 7
model.generation_config.__setattr__("lang_to_id", {language: lang_id})
model.generation_config.__setattr__("task_to_id", {task: task_id})
output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)
expected_output_start = [
*prompt_ids.tolist(),
model.generation_config.decoder_start_token_id,
lang_id,
task_id,
]
for row in output.numpy().tolist():
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = TFWhisperForConditionalGeneration(config)
input_features = input_dict["input_features"]
prompt_ids = np.asarray(range(5))
forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
output = model.generate(
input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
)
expected_output_start = [
*prompt_ids.tolist(),
model.generation_config.decoder_start_token_id,
*[token for _rank, token in forced_decoder_ids],
]
for row in output.numpy().tolist():
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
def _load_datasamples(num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

View File

@@ -21,7 +21,6 @@ import transformers
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
@@ -181,7 +180,7 @@ class FlaxXGLMModelTester:
@require_sentencepiece
@require_flax
class FlaxXGLMModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
class FlaxXGLMModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxXGLMModel, FlaxXGLMForCausalLM) if is_flax_available() else ()
def setUp(self):