Generation tests: update imagegpt input name, remove unused functions (#33663)

This commit is contained in:
Joao Gante
2024-09-24 16:40:48 +01:00
committed by GitHub
parent 6f7d750b73
commit a7734238ff
18 changed files with 23 additions and 656 deletions

View File

@@ -22,7 +22,7 @@ from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
@@ -116,20 +116,6 @@ class FlaxXGLMModelTester:
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_decoder(self):
config, input_ids, attention_mask = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)

View File

@@ -29,7 +29,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -125,26 +125,6 @@ class XGLMModelTester:
gradient_checkpointing=gradient_checkpointing,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
input_mask,
head_mask,
) = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
input_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_xglm_model(self, config, input_ids, input_mask, head_mask, *args):
model = XGLMModel(config=config)
model.to(torch_device)