[tests] remove tf/flax tests in /generation (#36235)
This commit is contained in:
@@ -19,7 +19,6 @@ import timeout_decorator # noqa
|
||||
from transformers import BartConfig, BartTokenizer, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -324,7 +323,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||
class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = True
|
||||
all_model_classes = (
|
||||
(
|
||||
|
||||
@@ -20,7 +20,6 @@ import timeout_decorator # noqa
|
||||
from transformers import BlenderbotConfig, is_flax_available
|
||||
from transformers.testing_utils import jax_device, require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -309,7 +308,7 @@ class BlenderbotHeadTests(unittest.TestCase):
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||
class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = True
|
||||
all_model_classes = (
|
||||
(
|
||||
|
||||
@@ -20,7 +20,6 @@ import timeout_decorator # noqa
|
||||
from transformers import BlenderbotSmallConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -308,7 +307,7 @@ class BlenderbotHeadTests(unittest.TestCase):
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||
class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = True
|
||||
all_model_classes = (
|
||||
(
|
||||
|
||||
@@ -18,7 +18,6 @@ import numpy as np
|
||||
from transformers import BloomConfig, BloomTokenizerFast, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -169,7 +168,7 @@ class FlaxBloomModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||
class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxBloomModel, FlaxBloomForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -18,7 +18,6 @@ import numpy as np
|
||||
from transformers import AutoTokenizer, GemmaConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, require_read_token, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -174,7 +173,7 @@ class FlaxGemmaModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGemmaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxGemmaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxGemmaModel, FlaxGemmaForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -22,7 +22,6 @@ import transformers
|
||||
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -209,7 +208,7 @@ class FlaxGPT2ModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -22,7 +22,6 @@ import transformers
|
||||
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -181,7 +180,7 @@ class FlaxGPTNeoModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGPTNeoModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxGPTNeoModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxGPTNeoModel, FlaxGPTNeoForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -22,7 +22,6 @@ import transformers
|
||||
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -178,7 +177,7 @@ class FlaxGPTJModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxGPTJModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxGPTJModel, FlaxGPTJForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -20,7 +20,6 @@ import numpy as np
|
||||
from transformers import LlamaConfig, is_flax_available, is_tokenizers_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -174,7 +173,7 @@ class FlaxLlamaModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxLlamaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxLlamaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxLlamaModel, FlaxLlamaForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -28,7 +28,6 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
)
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
@@ -235,7 +234,7 @@ class FlaxLongT5ModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxLongT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxLongT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxLongT5Model, FlaxLongT5ForConditionalGeneration) if is_flax_available() else ()
|
||||
is_encoder_decoder = True
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from transformers import MarianConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -228,7 +227,7 @@ class FlaxMarianModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxMarianModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||
class FlaxMarianModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = True
|
||||
all_model_classes = (FlaxMarianModel, FlaxMarianMTModel) if is_flax_available() else ()
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from transformers import MBartConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -330,7 +329,7 @@ class MBartHeadTests(unittest.TestCase):
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxMBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||
class FlaxMBartModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = True
|
||||
all_model_classes = (
|
||||
(
|
||||
|
||||
@@ -20,7 +20,6 @@ import numpy as np
|
||||
from transformers import MistralConfig, is_flax_available, is_tokenizers_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -185,7 +184,7 @@ class FlaxMistralModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxMistralModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxMistralModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxMistralModel, FlaxMistralForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -24,7 +24,6 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
)
|
||||
|
||||
from ...generation.test_tf_utils import TFGenerationIntegrationTests
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
@@ -244,7 +243,7 @@ class TFMistralModelTester:
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFMistralModelTest(TFModelTesterMixin, TFGenerationIntegrationTests, PipelineTesterMixin, unittest.TestCase):
|
||||
class TFMistralModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(TFMistralModel, TFMistralForCausalLM, TFMistralForSequenceClassification) if is_tf_available() else ()
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ import timeout_decorator # noqa
|
||||
from transformers import OPTConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, require_sentencepiece, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -203,7 +202,7 @@ class FlaxOPTModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxOPTModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||
class FlaxOPTModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxOPTModel, FlaxOPTForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -416,82 +416,6 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_features
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_features,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(
|
||||
model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
|
||||
)
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# overwritten from parent -- the input is `input_features`, not `input_ids`
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -27,7 +27,6 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
)
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
@@ -227,7 +226,7 @@ class FlaxT5ModelTester:
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxT5ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxT5Model, FlaxT5ForConditionalGeneration) if is_flax_available() else ()
|
||||
is_encoder_decoder = True
|
||||
|
||||
|
||||
@@ -524,127 +524,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_features
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_features,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(
|
||||
model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
|
||||
)
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
def test_generate_with_prompt_ids_and_task_and_language(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = TFWhisperForConditionalGeneration(config)
|
||||
input_features = input_dict["input_features"]
|
||||
prompt_ids = np.arange(5)
|
||||
language = "<|de|>"
|
||||
task = "translate"
|
||||
lang_id = 6
|
||||
task_id = 7
|
||||
model.generation_config.__setattr__("lang_to_id", {language: lang_id})
|
||||
model.generation_config.__setattr__("task_to_id", {task: task_id})
|
||||
|
||||
output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)
|
||||
|
||||
expected_output_start = [
|
||||
*prompt_ids.tolist(),
|
||||
model.generation_config.decoder_start_token_id,
|
||||
lang_id,
|
||||
task_id,
|
||||
]
|
||||
for row in output.numpy().tolist():
|
||||
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
|
||||
|
||||
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = TFWhisperForConditionalGeneration(config)
|
||||
input_features = input_dict["input_features"]
|
||||
prompt_ids = np.asarray(range(5))
|
||||
forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
|
||||
|
||||
output = model.generate(
|
||||
input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
|
||||
)
|
||||
|
||||
expected_output_start = [
|
||||
*prompt_ids.tolist(),
|
||||
model.generation_config.decoder_start_token_id,
|
||||
*[token for _rank, token in forced_decoder_ids],
|
||||
]
|
||||
for row in output.numpy().tolist():
|
||||
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
|
||||
|
||||
|
||||
def _load_datasamples(num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
@@ -21,7 +21,6 @@ import transformers
|
||||
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -181,7 +180,7 @@ class FlaxXGLMModelTester:
|
||||
|
||||
@require_sentencepiece
|
||||
@require_flax
|
||||
class FlaxXGLMModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
class FlaxXGLMModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxXGLMModel, FlaxXGLMForCausalLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user