VLMs: enable generation tests (#33533)
* add tests * fix whisper * update * nit * add qwen2-vl * more updates! * better this way * fix this one * fix more tests * fix final tests, hope so * fix led * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * pr comments * not pass pixels and extra for low-mem tests, very flaky because of vision tower --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e40bb4845e
commit
d7975a5874
@@ -1154,7 +1154,7 @@ class GenerationMixin:
|
|||||||
"Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
|
"Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.config.vocab_size == assistant_model.config.vocab_size:
|
if not self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
|
||||||
raise ValueError("Make sure the main and assistant model use the same tokenizer")
|
raise ValueError("Make sure the main and assistant model use the same tokenizer")
|
||||||
|
|
||||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||||
@@ -1476,7 +1476,7 @@ class GenerationMixin:
|
|||||||
layer_device_map = get_layer_device_map(execution_device_map)
|
layer_device_map = get_layer_device_map(execution_device_map)
|
||||||
|
|
||||||
cache_kwargs = {
|
cache_kwargs = {
|
||||||
"config": self.config if hasattr(self.config, "text_config") else self.config,
|
"config": self.config.get_text_config(),
|
||||||
"max_batch_size": batch_size,
|
"max_batch_size": batch_size,
|
||||||
"max_cache_len": max_cache_len,
|
"max_cache_len": max_cache_len,
|
||||||
"device": device,
|
"device": device,
|
||||||
|
|||||||
@@ -45,6 +45,74 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "PaliGemmaConfig"
|
_CONFIG_FOR_DOC = "PaliGemmaConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
# But Paligemma has no causal mask on prefix
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
    is_training: bool,
    token_type_ids: torch.Tensor,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case no padding is applied.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        is_training (`bool`):
            Whether the model is in training mode or in inference. The condition is checked by presence/absence of
            `token_type_ids/labels`
        token_type_ids (`torch.Tensor`):
            Only read when `is_training` is `True` and a 2D `attention_mask` is given. Key positions whose value
            is `0` are made visible to every query (their mask entry is set to `0`).

    Returns:
        `torch.Tensor`: the 4D mask, holding `0` where attention is allowed and `min_dtype` where it is blocked.
        A 4D `attention_mask` is returned as is (same object).
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for
        # prefix is handled below
        if sequence_length != 1:
            if is_training:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                causal_mask = torch.zeros_like(causal_mask)

        # Keep `min_dtype` only on key slots located strictly after each query's cache position
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # A sum of exactly 0 means: currently visible (0) AND padded in `attention_mask` (0) -> must be blocked
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )
            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
            if is_training:
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
    return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
|
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -285,7 +353,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
|
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
|
||||||
):
|
):
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
dtype, device = inputs_embeds.dtype, inputs_embeds.device
|
dtype = inputs_embeds.dtype
|
||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = inputs_embeds.shape[1]
|
sequence_length = inputs_embeds.shape[1]
|
||||||
if using_static_cache:
|
if using_static_cache:
|
||||||
@@ -299,19 +367,19 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
|
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
causal_mask = attention_mask
|
return attention_mask
|
||||||
else:
|
|
||||||
causal_mask = torch.full(
|
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
|
||||||
)
|
|
||||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
|
||||||
if sequence_length != 1:
|
|
||||||
if is_training:
|
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
|
||||||
else:
|
|
||||||
causal_mask = torch.zeros_like(causal_mask)
|
|
||||||
|
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
causal_mask = torch.full(
|
||||||
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||||
|
)
|
||||||
|
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||||
|
if sequence_length != 1:
|
||||||
|
if is_training:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
else:
|
||||||
|
causal_mask = torch.zeros_like(causal_mask)
|
||||||
|
|
||||||
|
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
@@ -420,7 +488,8 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
image_features = self.multi_modal_projector(selected_image_feature)
|
image_features = self.multi_modal_projector(selected_image_feature)
|
||||||
image_features = image_features / (self.config.hidden_size**0.5)
|
image_features = image_features / (self.config.hidden_size**0.5)
|
||||||
|
|
||||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||||
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
|
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -508,11 +577,38 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
use_cache=use_cache,
|
||||||
num_logits_to_keep=num_logits_to_keep,
|
num_logits_to_keep=num_logits_to_keep,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if model_inputs["inputs_embeds"] is not None:
|
||||||
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||||
|
device = model_inputs["inputs_embeds"].device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||||
|
device = model_inputs["input_ids"].device
|
||||||
|
|
||||||
|
dtype = self.get_output_embeddings().weight.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
is_training = token_type_ids is not None and kwargs.get("labels", None) is not None
|
||||||
|
|
||||||
|
model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_length(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
is_training=is_training,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
model_inputs["token_type_ids"] = token_type_ids
|
model_inputs["token_type_ids"] = token_type_ids
|
||||||
|
|
||||||
# position_ids in Paligemma are 1-indexed
|
# position_ids in Paligemma are 1-indexed
|
||||||
|
|||||||
@@ -1070,7 +1070,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
|||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
|
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
|
||||||
)
|
)
|
||||||
self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
|
self.merger = PatchMerger(
|
||||||
|
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
|
||||||
|
)
|
||||||
|
|
||||||
def get_dtype(self) -> torch.dtype:
|
def get_dtype(self) -> torch.dtype:
|
||||||
return self.blocks[0].mlp.fc2.weight.dtype
|
return self.blocks[0].mlp.fc2.weight.dtype
|
||||||
|
|||||||
@@ -98,10 +98,22 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
def _get_input_ids_and_config(self, batch_size=2):
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict[self.input_name]
|
# TODO: @raushan or @gante, use `model.main_input_name` as the main input instead of relying on `input_ids`
|
||||||
|
input_ids = inputs_dict.pop(self.input_name)[:batch_size, :]
|
||||||
|
inputs_dict.pop("attention_mask", None)
|
||||||
|
|
||||||
input_ids = input_ids[:batch_size]
|
# we don't want encoder-decoder models to start from filled decoder ids
|
||||||
|
inputs_dict.pop("decoder_input_ids", None)
|
||||||
|
inputs_dict.pop("decoder_attention_mask", None)
|
||||||
|
|
||||||
|
# we'll set cache use in each test differently
|
||||||
|
inputs_dict.pop("use_cache", None)
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
k: v[:batch_size, ...]
|
||||||
|
for k, v in inputs_dict.items()
|
||||||
|
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||||
|
}
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
if isinstance(config.eos_token_id, int):
|
if isinstance(config.eos_token_id, int):
|
||||||
@@ -118,7 +130,7 @@ class GenerationTesterMixin:
|
|||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
config.forced_eos_token_id = None
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask, inputs_dict
|
||||||
|
|
||||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
logits_processor_kwargs = {
|
logits_processor_kwargs = {
|
||||||
@@ -191,6 +203,7 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
inputs_dict,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -213,6 +226,7 @@ class GenerationTesterMixin:
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
@@ -222,6 +236,7 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
inputs_dict,
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
@@ -247,6 +262,7 @@ class GenerationTesterMixin:
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
@@ -256,6 +272,7 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
inputs_dict,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
@@ -279,6 +296,7 @@ class GenerationTesterMixin:
|
|||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
@@ -288,6 +306,7 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
inputs_dict,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
@@ -312,6 +331,7 @@ class GenerationTesterMixin:
|
|||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
@@ -321,6 +341,7 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
inputs_dict,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
@@ -344,6 +365,7 @@ class GenerationTesterMixin:
|
|||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
@@ -353,6 +375,7 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
inputs_dict,
|
||||||
constraints,
|
constraints,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
@@ -378,6 +401,7 @@ class GenerationTesterMixin:
|
|||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
@@ -387,6 +411,7 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
inputs_dict,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -415,6 +440,7 @@ class GenerationTesterMixin:
|
|||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
**contrastive_search_kwargs,
|
**contrastive_search_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
@@ -422,10 +448,12 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_greedy_generate(self):
|
def test_greedy_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(model=model, input_ids=input_ids, attention_mask=attention_mask)
|
output_generate = self._greedy_generate(
|
||||||
|
model=model, input_ids=input_ids, attention_mask=attention_mask, inputs_dict=inputs_dict
|
||||||
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
@@ -435,13 +463,14 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -466,7 +495,7 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
@@ -479,6 +508,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -497,13 +527,14 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -515,13 +546,14 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_sample_generate_dict_output(self):
|
def test_sample_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
num_return_sequences=2,
|
num_return_sequences=2,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
@@ -547,7 +579,7 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_beam_search_generate(self):
|
def test_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
@@ -556,6 +588,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -567,7 +600,7 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_beam_search_generate_dict_output(self):
|
def test_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
@@ -575,6 +608,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
@@ -602,7 +636,7 @@ class GenerationTesterMixin:
|
|||||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
@@ -618,6 +652,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
@@ -647,7 +682,7 @@ class GenerationTesterMixin:
|
|||||||
if model_class._no_split_modules is None:
|
if model_class._no_split_modules is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).eval()
|
model = model_class(config).eval()
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
@@ -659,12 +694,13 @@ class GenerationTesterMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_new_tokens=self.max_new_tokens,
|
max_new_tokens=self.max_new_tokens,
|
||||||
num_beams=2,
|
num_beams=2,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_beam_sample_generate(self):
|
def test_beam_sample_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
@@ -672,6 +708,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -680,28 +717,34 @@ class GenerationTesterMixin:
|
|||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
|
# for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly
|
||||||
# `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
|
# no quick fix available, since obtaining image embeddings step is very model-specific
|
||||||
# code is up to date with our most recent standards
|
if any(name in model.__class__.__name__.lower() for name in ("blip", "llava", "paligemma")):
|
||||||
if (
|
prepare_inputs_for_generation_args = set(
|
||||||
"inputs_embeds" in prepare_inputs_for_generation_args
|
inspect.signature(model.prepare_inputs_for_generation).parameters
|
||||||
and "cache_positions" in prepare_inputs_for_generation_args
|
|
||||||
):
|
|
||||||
input_embeds = model.get_input_embeddings()(input_ids)
|
|
||||||
beam_kwargs.update({"inputs_embeds": input_embeds})
|
|
||||||
output_generate2 = self._beam_sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=None,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
)
|
)
|
||||||
|
# `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
|
||||||
|
# code is up to date with our most recent standards
|
||||||
|
if (
|
||||||
|
"inputs_embeds" in prepare_inputs_for_generation_args
|
||||||
|
and "cache_positions" in prepare_inputs_for_generation_args
|
||||||
|
):
|
||||||
|
input_embeds = model.get_input_embeddings()(input_ids)
|
||||||
|
beam_kwargs.update({"inputs_embeds": input_embeds})
|
||||||
|
output_generate2 = self._beam_sample_generate(
|
||||||
|
model=model,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict={},
|
||||||
|
beam_kwargs=beam_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
|
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_beam_sample_generate_dict_output(self):
|
def test_beam_sample_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
@@ -710,6 +753,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
@@ -736,7 +780,7 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
config, _, _ = self._get_input_ids_and_config()
|
config, _, _, _ = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# if no bos token id => cannot generate from None
|
# if no bos token id => cannot generate from None
|
||||||
if config.bos_token_id is None:
|
if config.bos_token_id is None:
|
||||||
@@ -758,7 +802,7 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_group_beam_search_generate(self):
|
def test_group_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
# check `generate()` and `group_beam_search()` are equal
|
# check `generate()` and `group_beam_search()` are equal
|
||||||
@@ -767,6 +811,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -781,6 +826,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -791,7 +837,7 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_group_beam_search_generate_dict_output(self):
|
def test_group_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||||
@@ -799,6 +845,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
@@ -827,7 +874,7 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_constrained_beam_search_generate(self):
|
def test_constrained_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
@@ -845,6 +892,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
)
|
)
|
||||||
@@ -870,6 +918,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
)
|
)
|
||||||
@@ -885,7 +934,7 @@ class GenerationTesterMixin:
|
|||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_constrained_beam_search_generate_dict_output(self):
|
def test_constrained_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
@@ -902,6 +951,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
@@ -937,7 +987,7 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -947,7 +997,11 @@ class GenerationTesterMixin:
|
|||||||
# test old generation output for backwards compatibility
|
# test old generation output for backwards compatibility
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._contrastive_generate(
|
output_generate = self._contrastive_generate(
|
||||||
model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
|
model=model,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
@@ -964,7 +1018,7 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -976,6 +1030,7 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
inputs_dict=inputs_dict,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -1003,7 +1058,7 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||||
self.skipTest(reason="TODO: fix me")
|
self.skipTest(reason="TODO: fix me")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1021,6 +1076,7 @@ class GenerationTesterMixin:
|
|||||||
low_memory=True,
|
low_memory=True,
|
||||||
max_new_tokens=self.max_new_tokens,
|
max_new_tokens=self.max_new_tokens,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
**inputs_dict,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1031,6 +1087,7 @@ class GenerationTesterMixin:
|
|||||||
low_memory=False,
|
low_memory=False,
|
||||||
max_new_tokens=self.max_new_tokens,
|
max_new_tokens=self.max_new_tokens,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
**inputs_dict,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
@@ -1055,7 +1112,7 @@ class GenerationTesterMixin:
|
|||||||
]
|
]
|
||||||
):
|
):
|
||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
config, input_ids, _ = self._get_input_ids_and_config(batch_size=2)
|
config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2)
|
||||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
@@ -1065,7 +1122,12 @@ class GenerationTesterMixin:
|
|||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
low_output = model.generate(
|
low_output = model.generate(
|
||||||
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True
|
input_ids,
|
||||||
|
max_new_tokens=8,
|
||||||
|
num_beams=5,
|
||||||
|
early_stopping=True,
|
||||||
|
low_memory=True,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
high_output = model.generate(
|
high_output = model.generate(
|
||||||
@@ -1114,7 +1176,7 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1140,7 +1202,9 @@ class GenerationTesterMixin:
|
|||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_greedy = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||||
|
)
|
||||||
|
|
||||||
# test with the same assistant model or randomly init one
|
# test with the same assistant model or randomly init one
|
||||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||||
@@ -1152,7 +1216,9 @@ class GenerationTesterMixin:
|
|||||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||||
generation_kwargs.update({"assistant_model": assistant_model})
|
generation_kwargs.update({"assistant_model": assistant_model})
|
||||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_assisted = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||||
|
)
|
||||||
|
|
||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
|
|
||||||
@@ -1187,7 +1253,7 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1214,10 +1280,14 @@ class GenerationTesterMixin:
|
|||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_greedy = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||||
|
)
|
||||||
|
|
||||||
generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
|
generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
|
||||||
output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_prompt_lookup = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||||
|
)
|
||||||
|
|
||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
|
|
||||||
@@ -1239,7 +1309,7 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
|
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
|
||||||
|
|
||||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# Encoder-decoder models are not supported
|
# Encoder-decoder models are not supported
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
@@ -1267,7 +1337,7 @@ class GenerationTesterMixin:
|
|||||||
}
|
}
|
||||||
generation_kwargs.update({"dola_layers": "low"})
|
generation_kwargs.update({"dola_layers": "low"})
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
|
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs, **inputs_dict)
|
||||||
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
|
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@@ -1296,7 +1366,7 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1326,9 +1396,11 @@ class GenerationTesterMixin:
|
|||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_assisted = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||||
|
)
|
||||||
|
|
||||||
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
self._check_outputs(output_assisted, input_ids, config, use_cache=True)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_prompt_lookup_decoding_stops_at_eos(self):
|
def test_prompt_lookup_decoding_stops_at_eos(self):
|
||||||
@@ -1364,7 +1436,7 @@ class GenerationTesterMixin:
|
|||||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
# We want to test only encoder-decoder models
|
# We want to test only encoder-decoder models
|
||||||
if not config.is_encoder_decoder:
|
if not config.is_encoder_decoder:
|
||||||
continue
|
continue
|
||||||
@@ -1394,6 +1466,7 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
remove_invalid_values=True,
|
remove_invalid_values=True,
|
||||||
**{name: mask},
|
**{name: mask},
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||||
@@ -1416,7 +1489,7 @@ class GenerationTesterMixin:
|
|||||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||||
decoder_only_classes = []
|
decoder_only_classes = []
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, _, _ = self._get_input_ids_and_config()
|
config, _, _, _ = self._get_input_ids_and_config()
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@@ -1449,7 +1522,7 @@ class GenerationTesterMixin:
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
for model_class in decoder_only_classes:
|
for model_class in decoder_only_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
signature = inspect.signature(model.forward).parameters.keys()
|
signature = inspect.signature(model.forward).parameters.keys()
|
||||||
|
|
||||||
@@ -1462,7 +1535,9 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# With left-padding (length 32)
|
# With left-padding (length 32)
|
||||||
# can hardcode pad_token to be 0 as we'll do attn masking anyway
|
# can hardcode pad_token to be 0 as we'll do attn masking anyway
|
||||||
pad_token_id = config.pad_token_id if getattr(config, "pad_token_id") is not None else 0
|
pad_token_id = (
|
||||||
|
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
|
||||||
|
)
|
||||||
pad_size = (input_ids.shape[0], 32)
|
pad_size = (input_ids.shape[0], 32)
|
||||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
||||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||||
@@ -1550,7 +1625,7 @@ class GenerationTesterMixin:
|
|||||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, _ = self._get_input_ids_and_config()
|
config, input_ids, _, _ = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# Ignore:
|
# Ignore:
|
||||||
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
||||||
@@ -1572,25 +1647,23 @@ class GenerationTesterMixin:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Traditional way of generating text
|
# Traditional way of generating text
|
||||||
outputs_from_ids = model.generate(input_ids)
|
outputs_from_ids = model.generate(input_ids, max_new_tokens=5)
|
||||||
self.assertEqual(outputs_from_ids.shape, (2, 20))
|
self.assertEqual(outputs_from_ids.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||||
|
|
||||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
|
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
|
||||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5)
|
||||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||||
|
|
||||||
# But if we pass different inputs_embeds, we should get different outputs
|
# But if we pass different inputs_embeds, we should get different outputs
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
random_embeds = torch.rand_like(inputs_embeds)
|
random_embeds = torch.rand_like(inputs_embeds)
|
||||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds, max_new_tokens=5)
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||||
|
|
||||||
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
||||||
outputs_from_embeds_wo_ids = model.generate(
|
outputs_from_embeds_wo_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)
|
||||||
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
|
|
||||||
)
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||||
outputs_from_embeds_wo_ids.tolist(),
|
outputs_from_embeds_wo_ids.tolist(),
|
||||||
@@ -1607,7 +1680,7 @@ class GenerationTesterMixin:
|
|||||||
if not model_class._supports_static_cache:
|
if not model_class._supports_static_cache:
|
||||||
self.skipTest(reason="This model does not support the static cache format")
|
self.skipTest(reason="This model does not support the static cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||||
|
|
||||||
@@ -1621,27 +1694,30 @@ class GenerationTesterMixin:
|
|||||||
max_cache_len = 30
|
max_cache_len = 30
|
||||||
|
|
||||||
# here we force to not stop at eos and go until max-length
|
# here we force to not stop at eos and go until max-length
|
||||||
model.generation_config.eos_token_id = model.config.eos_token_id = -1
|
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"max_length": max_cache_len,
|
"max_length": max_cache_len,
|
||||||
"cache_implementation": "static",
|
"cache_implementation": "static",
|
||||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
text_config = model.config.get_text_config()
|
||||||
head_dim = (
|
head_dim = (
|
||||||
model.config.head_dim
|
text_config.head_dim
|
||||||
if hasattr(model.config, "head_dim")
|
if hasattr(text_config, "head_dim")
|
||||||
else model.config.hidden_size // model.config.num_attention_heads
|
else text_config.hidden_size // text_config.num_attention_heads
|
||||||
)
|
)
|
||||||
num_key_value_heads = (
|
num_key_value_heads = (
|
||||||
model.config.num_attention_heads
|
text_config.num_attention_heads
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
if getattr(text_config, "num_key_value_heads", None) is None
|
||||||
else model.config.num_key_value_heads
|
else text_config.num_key_value_heads
|
||||||
)
|
)
|
||||||
num_hidden_layers = config.num_hidden_layers
|
num_hidden_layers = text_config.num_hidden_layers
|
||||||
|
|
||||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||||
outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
|
outputs = model.generate(
|
||||||
|
inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||||
|
)
|
||||||
|
|
||||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
# we should get `max_length` in shape, not `max_length - embeds_length`
|
||||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||||
@@ -1742,7 +1818,7 @@ class GenerationTesterMixin:
|
|||||||
if not model_class._supports_cache_class:
|
if not model_class._supports_cache_class:
|
||||||
self.skipTest(reason="This model does not support the new cache format")
|
self.skipTest(reason="This model does not support the new cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
@@ -1757,7 +1833,9 @@ class GenerationTesterMixin:
|
|||||||
# Sets seed before calling `generate` for the case with do_sample=True
|
# Sets seed before calling `generate` for the case with do_sample=True
|
||||||
seed = torch.randint(0, 1000000, (1,)).item()
|
seed = torch.randint(0, 1000000, (1,)).item()
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
legacy_results = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||||
|
)
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
cache_cls = EncoderDecoderCache
|
cache_cls = EncoderDecoderCache
|
||||||
@@ -1766,7 +1844,11 @@ class GenerationTesterMixin:
|
|||||||
cache_cls = DynamicCache
|
cache_cls = DynamicCache
|
||||||
past_key_values = cache_cls()
|
past_key_values = cache_cls()
|
||||||
new_results = model.generate(
|
new_results = model.generate(
|
||||||
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
**generation_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||||
@@ -1810,7 +1892,7 @@ class GenerationTesterMixin:
|
|||||||
if not model_class._supports_static_cache:
|
if not model_class._supports_static_cache:
|
||||||
self.skipTest(reason="This model does not support the static cache format")
|
self.skipTest(reason="This model does not support the static cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||||
|
|
||||||
@@ -1838,7 +1920,7 @@ class GenerationTesterMixin:
|
|||||||
else config.num_key_value_heads
|
else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
num_hidden_layers = config.num_hidden_layers
|
num_hidden_layers = config.num_hidden_layers
|
||||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||||
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||||
@@ -1852,7 +1934,7 @@ class GenerationTesterMixin:
|
|||||||
if not model_class._supports_quantized_cache:
|
if not model_class._supports_quantized_cache:
|
||||||
self.skipTest(reason="This model does not support the quantized cache format")
|
self.skipTest(reason="This model does not support the quantized cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
@@ -1865,7 +1947,7 @@ class GenerationTesterMixin:
|
|||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
|
||||||
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
|
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
|
||||||
|
|
||||||
# passing past key values of different type should raise Error
|
# passing past key values of different type should raise Error
|
||||||
@@ -1931,7 +2013,7 @@ class GenerationTesterMixin:
|
|||||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
@@ -1946,10 +2028,12 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||||
with_all_logits = model.generate(
|
with_all_logits = model.generate(
|
||||||
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
|
||||||
)
|
)
|
||||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||||
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
without_all_logits = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
|
||||||
|
)
|
||||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||||
|
|
||||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||||
@@ -1959,7 +2043,7 @@ class GenerationTesterMixin:
|
|||||||
if model_class._is_stateful:
|
if model_class._is_stateful:
|
||||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
@@ -1976,10 +2060,12 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||||
with_all_logits = model.generate(
|
with_all_logits = model.generate(
|
||||||
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
|
||||||
)
|
)
|
||||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||||
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
without_all_logits = model.generate(
|
||||||
|
input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
|
||||||
|
)
|
||||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||||
|
|||||||
@@ -289,7 +289,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.attention_type = "original_full"
|
config.attention_type = "original_full"
|
||||||
|
|
||||||
input_ids = inputs_dict[self.input_name]
|
input_ids = inputs_dict.pop(self.input_name)
|
||||||
|
_ = inputs_dict.pop("attention_mask", None)
|
||||||
|
_ = inputs_dict.pop("decoder_input_ids", None)
|
||||||
|
_ = inputs_dict.pop("decoder_attention_mask", None)
|
||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||||
|
|
||||||
# cut to half length & take max batch_size 3
|
# cut to half length & take max batch_size 3
|
||||||
@@ -300,7 +303,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
config.pad_token_id = config.eos_token_id
|
config.pad_token_id = config.eos_token_id
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask, inputs_dict
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = BigBirdPegasusModelTester(self)
|
self.model_tester = BigBirdPegasusModelTester(self)
|
||||||
|
|||||||
@@ -389,10 +389,6 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)
|
self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skip(reason="Bloom has a non-standard KV cache format.")
|
|
||||||
def test_past_key_values_format(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model_name = "bigscience/bigscience-small-testing"
|
model_name = "bigscience/bigscience-small-testing"
|
||||||
|
|||||||
@@ -450,6 +450,53 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
config_and_inputs[0].position_embedding_type = type
|
config_and_inputs[0].position_embedding_type = type
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def _check_attentions_for_generate(
|
||||||
|
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||||
|
):
|
||||||
|
# GIT attention shape depends on image inputs, overwrite
|
||||||
|
self.assertIsInstance(attentions, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||||
|
)
|
||||||
|
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||||
|
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
||||||
|
|
||||||
|
for idx, iter_attentions in enumerate(attentions):
|
||||||
|
tgt_len = min_length + idx + image_length if not use_cache else 1
|
||||||
|
src_len = min_length + idx + image_length
|
||||||
|
|
||||||
|
expected_shape = (
|
||||||
|
batch_size * num_beam_groups,
|
||||||
|
config.num_attention_heads,
|
||||||
|
tgt_len,
|
||||||
|
src_len,
|
||||||
|
)
|
||||||
|
# check attn size
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_hidden_states_for_generate(
|
||||||
|
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||||
|
):
|
||||||
|
# GIT attention shape depends on image inputs, overwrite
|
||||||
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||||
|
[True] * len(hidden_states),
|
||||||
|
)
|
||||||
|
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||||
|
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
||||||
|
|
||||||
|
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||||
|
seq_len = min_length + idx + image_length if not use_cache else 1
|
||||||
|
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||||
|
# check hidden size
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||||
|
[expected_shape] * len(iter_hidden_states),
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model_name = "microsoft/git-base"
|
model_name = "microsoft/git-base"
|
||||||
@@ -468,10 +515,18 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||||
|
def test_contrastive_generate_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="GIT has pixel values as additional input")
|
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||||
|
def test_dola_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
|
|||||||
@@ -338,6 +338,14 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
self.model_tester.check_global_attention(*config_and_inputs)
|
self.model_tester.check_global_attention(*config_and_inputs)
|
||||||
|
|
||||||
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
|
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(
|
||||||
|
self, batch_size=batch_size
|
||||||
|
)
|
||||||
|
# LED computes attention scores based on mask indices if `is_global`
|
||||||
|
inputs_dict.pop("global_attention_mask")
|
||||||
|
return config, input_ids, attention_mask, inputs_dict
|
||||||
|
|
||||||
# LEDForSequenceClassification does not support inputs_embeds
|
# LEDForSequenceClassification does not support inputs_embeds
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ class LlavaVisionText2TextModelTester:
|
|||||||
"initializer_range": 0.02,
|
"initializer_range": 0.02,
|
||||||
"num_labels": 3,
|
"num_labels": 3,
|
||||||
"num_choices": 4,
|
"num_choices": 4,
|
||||||
"pad_token_id": 0,
|
"pad_token_id": 1,
|
||||||
},
|
},
|
||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config={
|
vision_config={
|
||||||
@@ -106,7 +107,7 @@ class LlavaVisionText2TextModelTester:
|
|||||||
self.vision_feature_layer = vision_feature_layer
|
self.vision_feature_layer = vision_feature_layer
|
||||||
self.text_config = text_config
|
self.text_config = text_config
|
||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
self.seq_length = seq_length
|
self.pad_token_id = text_config["pad_token_id"]
|
||||||
|
|
||||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
self.vocab_size = text_config["vocab_size"]
|
self.vocab_size = text_config["vocab_size"]
|
||||||
@@ -118,6 +119,8 @@ class LlavaVisionText2TextModelTester:
|
|||||||
self.num_channels = 3
|
self.num_channels = 3
|
||||||
self.image_size = 336
|
self.image_size = 336
|
||||||
self.encoder_seq_length = 231
|
self.encoder_seq_length = 231
|
||||||
|
self.num_image_tokens = 224
|
||||||
|
self.seq_length = seq_length + self.num_image_tokens
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return LlavaConfig(
|
return LlavaConfig(
|
||||||
@@ -128,6 +131,7 @@ class LlavaVisionText2TextModelTester:
|
|||||||
projector_hidden_act=self.projector_hidden_act,
|
projector_hidden_act=self.projector_hidden_act,
|
||||||
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
||||||
vision_feature_layer=self.vision_feature_layer,
|
vision_feature_layer=self.vision_feature_layer,
|
||||||
|
image_seq_length=self.num_image_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -148,8 +152,8 @@ class LlavaVisionText2TextModelTester:
|
|||||||
config, pixel_values = config_and_inputs
|
config, pixel_values = config_and_inputs
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
# we are giving 3 images let's make sure we pass in 3 image tokens
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||||
input_ids[:, 1] = config.image_token_index
|
input_ids[:, : self.num_image_tokens] = config.image_token_index
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
@@ -172,12 +176,13 @@ class LlavaVisionText2TextModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
|
class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Model tester for `LlavaForConditionalGeneration`.
|
Model tester for `LlavaForConditionalGeneration`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {}
|
pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {}
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|||||||
@@ -86,12 +86,12 @@ class LlavaNextVisionText2TextModelTester:
|
|||||||
"initializer_range": 0.02,
|
"initializer_range": 0.02,
|
||||||
"num_labels": 3,
|
"num_labels": 3,
|
||||||
"num_choices": 4,
|
"num_choices": 4,
|
||||||
"pad_token_id": 0,
|
"pad_token_id": 1,
|
||||||
},
|
},
|
||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config={
|
vision_config={
|
||||||
"image_size": 16,
|
"image_size": 16,
|
||||||
"patch_size": 2,
|
"patch_size": 4,
|
||||||
"num_channels": 3,
|
"num_channels": 3,
|
||||||
"is_training": True,
|
"is_training": True,
|
||||||
"hidden_size": 32,
|
"hidden_size": 32,
|
||||||
@@ -112,7 +112,7 @@ class LlavaNextVisionText2TextModelTester:
|
|||||||
self.vision_feature_layer = vision_feature_layer
|
self.vision_feature_layer = vision_feature_layer
|
||||||
self.text_config = text_config
|
self.text_config = text_config
|
||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
self.seq_length = seq_length
|
self.pad_token_id = text_config["pad_token_id"]
|
||||||
|
|
||||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
self.vocab_size = text_config["vocab_size"]
|
self.vocab_size = text_config["vocab_size"]
|
||||||
@@ -123,8 +123,10 @@ class LlavaNextVisionText2TextModelTester:
|
|||||||
self.batch_size = 3
|
self.batch_size = 3
|
||||||
self.num_channels = 3
|
self.num_channels = 3
|
||||||
self.image_size = 30
|
self.image_size = 30
|
||||||
self.encoder_seq_length = 342
|
self.encoder_seq_length = 95
|
||||||
self.image_grid_pinpoints = [[32, 32]]
|
self.image_grid_pinpoints = [[32, 32]]
|
||||||
|
self.num_image_tokens = 88
|
||||||
|
self.seq_length = seq_length + self.num_image_tokens
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return LlavaNextConfig(
|
return LlavaNextConfig(
|
||||||
@@ -136,6 +138,7 @@ class LlavaNextVisionText2TextModelTester:
|
|||||||
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
||||||
vision_feature_layer=self.vision_feature_layer,
|
vision_feature_layer=self.vision_feature_layer,
|
||||||
image_grid_pinpoints=self.image_grid_pinpoints,
|
image_grid_pinpoints=self.image_grid_pinpoints,
|
||||||
|
image_seq_length=self.num_image_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -157,11 +160,10 @@ class LlavaNextVisionText2TextModelTester:
|
|||||||
config, pixel_values = config_and_inputs
|
config, pixel_values = config_and_inputs
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||||
# we are giving 3 images let's make sure we pass in 3 image tokens
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||||
input_ids[:, 1] = config.image_token_index
|
|
||||||
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
input_ids[:, : self.num_image_tokens] = config.image_token_index
|
||||||
# maskout where the image token is
|
|
||||||
labels[:, 1] == self.ignore_index
|
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"image_sizes": torch.tensor(
|
"image_sizes": torch.tensor(
|
||||||
@@ -169,7 +171,6 @@ class LlavaNextVisionText2TextModelTester:
|
|||||||
),
|
),
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"labels": labels,
|
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -214,6 +215,7 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -87,12 +87,12 @@ class LlavaNextVideoVisionText2TextModelTester:
|
|||||||
"initializer_range": 0.02,
|
"initializer_range": 0.02,
|
||||||
"num_labels": 3,
|
"num_labels": 3,
|
||||||
"num_choices": 4,
|
"num_choices": 4,
|
||||||
"pad_token_id": 0,
|
"pad_token_id": 2,
|
||||||
},
|
},
|
||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config={
|
vision_config={
|
||||||
"image_size": 16,
|
"image_size": 16,
|
||||||
"patch_size": 2,
|
"patch_size": 4,
|
||||||
"num_channels": 3,
|
"num_channels": 3,
|
||||||
"is_training": True,
|
"is_training": True,
|
||||||
"hidden_size": 32,
|
"hidden_size": 32,
|
||||||
@@ -114,7 +114,7 @@ class LlavaNextVideoVisionText2TextModelTester:
|
|||||||
self.vision_feature_layer = vision_feature_layer
|
self.vision_feature_layer = vision_feature_layer
|
||||||
self.text_config = text_config
|
self.text_config = text_config
|
||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
self.seq_length = seq_length
|
self.pad_token_id = text_config["pad_token_id"]
|
||||||
|
|
||||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
self.vocab_size = text_config["vocab_size"]
|
self.vocab_size = text_config["vocab_size"]
|
||||||
@@ -125,8 +125,11 @@ class LlavaNextVideoVisionText2TextModelTester:
|
|||||||
self.batch_size = 3
|
self.batch_size = 3
|
||||||
self.num_channels = 3
|
self.num_channels = 3
|
||||||
self.image_size = 30
|
self.image_size = 30
|
||||||
self.encoder_seq_length = 469
|
self.encoder_seq_length = 127
|
||||||
self.image_grid_pinpoints = [[32, 32]]
|
self.image_grid_pinpoints = [[32, 32]]
|
||||||
|
self.num_image_tokens = 88
|
||||||
|
self.num_video_tokens = 32
|
||||||
|
self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return LlavaNextVideoConfig(
|
return LlavaNextVideoConfig(
|
||||||
@@ -139,6 +142,8 @@ class LlavaNextVideoVisionText2TextModelTester:
|
|||||||
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
||||||
vision_feature_layer=self.vision_feature_layer,
|
vision_feature_layer=self.vision_feature_layer,
|
||||||
image_grid_pinpoints=self.image_grid_pinpoints,
|
image_grid_pinpoints=self.image_grid_pinpoints,
|
||||||
|
video_seq_length=self.num_video_tokens,
|
||||||
|
image_seq_length=self.num_image_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -168,13 +173,12 @@ class LlavaNextVideoVisionText2TextModelTester:
|
|||||||
config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
|
config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||||
# we are giving 3 images and videos let's make sure we pass in 3 special tokens
|
|
||||||
input_ids[:, 1] = config.image_token_index
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||||
input_ids[:, 2] = config.video_token_index
|
input_ids[input_ids == config.video_token_index] = self.pad_token_id
|
||||||
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
input_ids[:, : self.num_image_tokens] = config.image_token_index
|
||||||
# maskout where the image/video token is
|
input_ids[:, self.num_image_tokens : self.num_video_tokens + self.num_image_tokens] = config.video_token_index
|
||||||
labels[:, 1] == self.ignore_index
|
|
||||||
labels[:, 2] == self.ignore_index
|
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"pixel_values_videos": pixel_values_videos,
|
"pixel_values_videos": pixel_values_videos,
|
||||||
@@ -183,7 +187,6 @@ class LlavaNextVideoVisionText2TextModelTester:
|
|||||||
),
|
),
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"labels": labels,
|
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -230,6 +233,7 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class LlavaOnevisionVisionText2TextModelTester:
|
|||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
ignore_index=-100,
|
ignore_index=-100,
|
||||||
image_token_index=0,
|
image_token_index=1,
|
||||||
projector_hidden_act="gelu",
|
projector_hidden_act="gelu",
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
vision_feature_select_strategy="full",
|
vision_feature_select_strategy="full",
|
||||||
@@ -92,7 +92,7 @@ class LlavaOnevisionVisionText2TextModelTester:
|
|||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config={
|
vision_config={
|
||||||
"image_size": 16,
|
"image_size": 16,
|
||||||
"patch_size": 2,
|
"patch_size": 8,
|
||||||
"num_channels": 3,
|
"num_channels": 3,
|
||||||
"is_training": True,
|
"is_training": True,
|
||||||
"hidden_size": 32,
|
"hidden_size": 32,
|
||||||
@@ -113,7 +113,9 @@ class LlavaOnevisionVisionText2TextModelTester:
|
|||||||
self.vision_feature_layer = vision_feature_layer
|
self.vision_feature_layer = vision_feature_layer
|
||||||
self.text_config = text_config
|
self.text_config = text_config
|
||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
self.seq_length = seq_length
|
self.pad_token_id = text_config["pad_token_id"]
|
||||||
|
self.num_image_tokens = 10
|
||||||
|
self.seq_length = seq_length + self.num_image_tokens
|
||||||
|
|
||||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
self.vocab_size = text_config["vocab_size"]
|
self.vocab_size = text_config["vocab_size"]
|
||||||
@@ -124,8 +126,7 @@ class LlavaOnevisionVisionText2TextModelTester:
|
|||||||
self.batch_size = 3
|
self.batch_size = 3
|
||||||
self.num_channels = 3
|
self.num_channels = 3
|
||||||
self.image_size = 30
|
self.image_size = 30
|
||||||
self.encoder_seq_length = 7
|
self.image_grid_pinpoints = [[16, 16]]
|
||||||
self.image_grid_pinpoints = [[32, 32]]
|
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return LlavaOnevisionConfig(
|
return LlavaOnevisionConfig(
|
||||||
@@ -143,7 +144,7 @@ class LlavaOnevisionVisionText2TextModelTester:
|
|||||||
pixel_values = floats_tensor(
|
pixel_values = floats_tensor(
|
||||||
[
|
[
|
||||||
self.batch_size,
|
self.batch_size,
|
||||||
9,
|
3,
|
||||||
self.vision_config["num_channels"],
|
self.vision_config["num_channels"],
|
||||||
self.vision_config["image_size"],
|
self.vision_config["image_size"],
|
||||||
self.vision_config["image_size"],
|
self.vision_config["image_size"],
|
||||||
@@ -158,16 +159,16 @@ class LlavaOnevisionVisionText2TextModelTester:
|
|||||||
config, pixel_values = config_and_inputs
|
config, pixel_values = config_and_inputs
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||||
# we are giving 3 images let's make sure we pass in 3 image tokens
|
|
||||||
input_ids[:, 1] = config.image_token_index
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||||
|
input_ids[:, : self.num_image_tokens] = config.image_token_index
|
||||||
|
|
||||||
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
||||||
# maskout where the image token is
|
labels[:, : self.num_image_tokens] == self.ignore_index
|
||||||
labels[:, 1] == self.ignore_index
|
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"image_sizes": torch.tensor(
|
"image_sizes": torch.tensor([[45, 45]] * self.batch_size),
|
||||||
[[self.vision_config["image_size"], self.vision_config["image_size"]]] * self.batch_size
|
|
||||||
),
|
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
|
|||||||
@@ -286,12 +286,19 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
|
||||||
|
_ = inputs_dict.pop("attention_mask", None)
|
||||||
|
inputs_dict = {
|
||||||
|
k: v[:batch_size, ...]
|
||||||
|
for k, v in inputs_dict.items()
|
||||||
|
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||||
|
}
|
||||||
|
|
||||||
# take max batch_size
|
# take max batch_size
|
||||||
sequence_length = input_ids.shape[-1]
|
sequence_length = input_ids.shape[-1]
|
||||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||||
|
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask, inputs_dict
|
||||||
|
|
||||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
logits_processor_kwargs = {}
|
logits_processor_kwargs = {}
|
||||||
@@ -299,7 +306,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
@@ -310,6 +317,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
inputs_dict={},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|||||||
@@ -289,12 +289,19 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
|
||||||
|
_ = inputs_dict.pop("attention_mask", None)
|
||||||
|
inputs_dict = {
|
||||||
|
k: v[:batch_size, ...]
|
||||||
|
for k, v in inputs_dict.items()
|
||||||
|
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||||
|
}
|
||||||
|
|
||||||
# take max batch_size
|
# take max batch_size
|
||||||
sequence_length = input_ids.shape[-1]
|
sequence_length = input_ids.shape[-1]
|
||||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||||
|
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask, inputs_dict
|
||||||
|
|
||||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
logits_processor_kwargs = {}
|
logits_processor_kwargs = {}
|
||||||
@@ -302,7 +309,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
@@ -313,6 +320,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
inputs_dict={},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
@@ -82,7 +83,7 @@ class PaliGemmaVisionText2TextModelTester:
|
|||||||
"initializer_range": 0.02,
|
"initializer_range": 0.02,
|
||||||
"num_labels": 3,
|
"num_labels": 3,
|
||||||
"num_choices": 4,
|
"num_choices": 4,
|
||||||
"pad_token_id": 0,
|
"pad_token_id": 1,
|
||||||
},
|
},
|
||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config={
|
vision_config={
|
||||||
@@ -115,6 +116,7 @@ class PaliGemmaVisionText2TextModelTester:
|
|||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.projection_dim = projection_dim
|
self.projection_dim = projection_dim
|
||||||
|
self.pad_token_id = text_config["pad_token_id"]
|
||||||
|
|
||||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
self.vocab_size = text_config["vocab_size"]
|
self.vocab_size = text_config["vocab_size"]
|
||||||
@@ -160,7 +162,7 @@ class PaliGemmaVisionText2TextModelTester:
|
|||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
# set the 16 first tokens to be image, and ensure that no other tokens are image tokens
|
# set the 16 first tokens to be image, and ensure that no other tokens are image tokens
|
||||||
# do not change this unless you modified image size or patch size
|
# do not change this unless you modified image size or patch size
|
||||||
input_ids = torch.where(input_ids == config.image_token_index, 2, input_ids)
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||||
input_ids[:, :16] = config.image_token_index
|
input_ids[:, :16] = config.image_token_index
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
@@ -173,12 +175,13 @@ class PaliGemmaVisionText2TextModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
|
class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Model tester for `PaliGemmaForConditionalGeneration`.
|
Model tester for `PaliGemmaForConditionalGeneration`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
@@ -305,6 +308,12 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, unittest.Test
|
|||||||
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="VLMs doen't accept inputs embeds and pixel values at the same time. So if the test passed for bacbone LM, it passes for VLM also"
|
||||||
|
)
|
||||||
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -58,11 +58,11 @@ class Qwen2VLVisionText2TextModelTester:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=8,
|
batch_size=2,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
num_channels=3,
|
num_channels=3,
|
||||||
ignore_index=-100,
|
ignore_index=-100,
|
||||||
image_size=28,
|
image_size=14,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
eos_token_id=1,
|
eos_token_id=1,
|
||||||
pad_token_id=2,
|
pad_token_id=2,
|
||||||
@@ -90,7 +90,7 @@ class Qwen2VLVisionText2TextModelTester:
|
|||||||
"mlp_ratio": 4,
|
"mlp_ratio": 4,
|
||||||
"num_heads": 4,
|
"num_heads": 4,
|
||||||
"patch_size": 14,
|
"patch_size": 14,
|
||||||
"spatial_merge_size": 2,
|
"spatial_merge_size": 1,
|
||||||
"temporal_patch_size": 2,
|
"temporal_patch_size": 2,
|
||||||
},
|
},
|
||||||
rope_scaling={"type": "mrope", "mrope_section": [2, 1, 1]},
|
rope_scaling={"type": "mrope", "mrope_section": [2, 1, 1]},
|
||||||
@@ -119,9 +119,10 @@ class Qwen2VLVisionText2TextModelTester:
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.seq_length = seq_length
|
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
self.num_image_tokens = 32
|
||||||
|
self.seq_length = seq_length + self.num_image_tokens
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return Qwen2VLConfig(
|
return Qwen2VLConfig(
|
||||||
@@ -162,23 +163,19 @@ class Qwen2VLVisionText2TextModelTester:
|
|||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
config, pixel_values = config_and_inputs
|
config, pixel_values = config_and_inputs
|
||||||
vision_seqlen = pixel_values.shape[0] // self.batch_size // (self.vision_config["spatial_merge_size"] ** 2)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length - 1 + vision_seqlen], self.vocab_size)
|
|
||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||||
input_ids[:, torch.arange(vision_seqlen, device=torch_device) + 1] = self.image_token_id
|
input_ids[:, self.num_image_tokens] = self.image_token_id
|
||||||
labels = torch.zeros(
|
labels = torch.zeros(
|
||||||
(self.batch_size, self.seq_length - 1 + vision_seqlen),
|
(self.batch_size, self.seq_length),
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
patch_size = self.vision_config["patch_size"]
|
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"image_grid_thw": torch.tensor(
|
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
|
||||||
[[1, self.image_size // patch_size, self.image_size // patch_size]] * self.batch_size
|
|
||||||
),
|
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
@@ -312,6 +309,12 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
def test_beam_search_low_memory(self):
|
def test_beam_search_low_memory(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
|
||||||
|
)
|
||||||
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Qwen2VLIntegrationTest(unittest.TestCase):
|
class Qwen2VLIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -689,12 +689,15 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
|||||||
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
|
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
|
||||||
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
|
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict[self.input_name]
|
input_ids = inputs_dict.pop(self.input_name)
|
||||||
|
_ = inputs_dict.pop("attention_mask", None)
|
||||||
|
_ = inputs_dict.pop("decoder_input_ids", None)
|
||||||
|
_ = inputs_dict.pop("decoder_attention_mask", None)
|
||||||
input_ids = input_ids[:batch_size, :16]
|
input_ids = input_ids[:batch_size, :16]
|
||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
|
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
config.forced_eos_token_id = None
|
config.forced_eos_token_id = None
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
input_name = "input_features"
|
input_name = "input_features"
|
||||||
|
|
||||||
def _get_input_ids_and_config(self, batch_size=2):
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
config, input_ids, attention_mask = GenerationTesterMixin._get_input_ids_and_config(self)
|
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(self)
|
||||||
|
|
||||||
# `input_ids` is actually `input_features` which is a 3D tensor.
|
# `input_ids` is actually `input_features` which is a 3D tensor.
|
||||||
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
|
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
|
||||||
@@ -294,7 +294,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
sequence_length = input_ids.shape[1]
|
sequence_length = input_ids.shape[1]
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
|
||||||
|
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask, inputs_dict
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = Speech2TextModelTester(self)
|
self.model_tester = Speech2TextModelTester(self)
|
||||||
|
|||||||
@@ -75,14 +75,14 @@ class VideoLlavaVisionText2TextModelTester:
|
|||||||
"initializer_range": 0.02,
|
"initializer_range": 0.02,
|
||||||
"num_labels": 3,
|
"num_labels": 3,
|
||||||
"num_choices": 4,
|
"num_choices": 4,
|
||||||
"pad_token_id": 0,
|
"pad_token_id": 3,
|
||||||
},
|
},
|
||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config={
|
vision_config={
|
||||||
"model_type": "clip_vision_model",
|
"model_type": "clip_vision_model",
|
||||||
"batch_size": 12,
|
"batch_size": 12,
|
||||||
"image_size": 30,
|
"image_size": 30,
|
||||||
"patch_size": 2,
|
"patch_size": 6,
|
||||||
"num_channels": 3,
|
"num_channels": 3,
|
||||||
"is_training": True,
|
"is_training": True,
|
||||||
"hidden_size": 32,
|
"hidden_size": 32,
|
||||||
@@ -104,8 +104,8 @@ class VideoLlavaVisionText2TextModelTester:
|
|||||||
self.vision_feature_layer = vision_feature_layer
|
self.vision_feature_layer = vision_feature_layer
|
||||||
self.text_config = text_config
|
self.text_config = text_config
|
||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
self.seq_length = seq_length
|
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
|
self.pad_token_id = text_config["pad_token_id"]
|
||||||
|
|
||||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
self.vocab_size = text_config["vocab_size"]
|
self.vocab_size = text_config["vocab_size"]
|
||||||
@@ -116,7 +116,10 @@ class VideoLlavaVisionText2TextModelTester:
|
|||||||
self.batch_size = 5
|
self.batch_size = 5
|
||||||
self.num_channels = 3
|
self.num_channels = 3
|
||||||
self.image_size = 224
|
self.image_size = 224
|
||||||
self.encoder_seq_length = 2044
|
self.encoder_seq_length = 64
|
||||||
|
self.num_image_tokens = 25
|
||||||
|
self.num_video_tokens = 26
|
||||||
|
self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return VideoLlavaConfig(
|
return VideoLlavaConfig(
|
||||||
@@ -128,6 +131,8 @@ class VideoLlavaVisionText2TextModelTester:
|
|||||||
projector_hidden_act=self.projector_hidden_act,
|
projector_hidden_act=self.projector_hidden_act,
|
||||||
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
||||||
vision_feature_layer=self.vision_feature_layer,
|
vision_feature_layer=self.vision_feature_layer,
|
||||||
|
image_seq_length=self.num_image_tokens,
|
||||||
|
video_seq_length=self.num_video_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -159,11 +164,11 @@ class VideoLlavaVisionText2TextModelTester:
|
|||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
|
||||||
# we are giving 3 videos and 3 images. Need to pass in image and video tokens, both
|
input_ids[(input_ids == config.image_token_index) | (input_ids == config.video_token_index)] = (
|
||||||
# also need to make sure no other special tokens are set
|
self.pad_token_id
|
||||||
input_ids[(input_ids == 0) | (input_ids == 1)] = 3
|
)
|
||||||
input_ids[:, 0] = config.video_token_index
|
input_ids[:, : self.num_image_tokens] = config.image_token_index
|
||||||
input_ids[:, 1:2] = config.image_token_index
|
input_ids[:, self.num_image_tokens : self.num_video_tokens + self.num_image_tokens] = config.video_token_index
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values_videos": pixel_values_videos,
|
"pixel_values_videos": pixel_values_videos,
|
||||||
"pixel_values_images": pixel_values_images,
|
"pixel_values_images": pixel_values_images,
|
||||||
@@ -196,6 +201,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
@@ -242,16 +248,16 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
# if we remove some images from inputs leaving only one
|
# if we remove some images from inputs leaving only one
|
||||||
# image number mismatch error should raise
|
# image number mismatch error should raise
|
||||||
inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
|
inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(RuntimeError):
|
||||||
_ = model(**inputs)
|
_ = model(**inputs)
|
||||||
|
|
||||||
def test_video_only_input(self):
|
def test_video_only_input(self):
|
||||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
# replace video_token with dummy id which is not video token id
|
# replace image token id with dummy id
|
||||||
# error that video-tokens and num-of-video-inputs mismatch will be raised
|
# Error will be raised as num-image-tokens and num-of-image-embeds mismatch
|
||||||
inputs["input_ids"][:, 1:2] = 2
|
inputs["input_ids"][:, : self.model_tester.num_image_tokens] = 2
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = model(**inputs)
|
_ = model(**inputs)
|
||||||
|
|
||||||
@@ -262,8 +268,13 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
# set dummy id, which is not image token id, same as above
|
# set dummy id, which is not video token id
|
||||||
inputs["input_ids"][:, :1] = 2
|
# Error will be raised as num-video-tokens and num-of-video-embeds mismatch
|
||||||
|
inputs["input_ids"][
|
||||||
|
:,
|
||||||
|
self.model_tester.num_image_tokens : self.model_tester.num_image_tokens
|
||||||
|
+ self.model_tester.num_video_tokens,
|
||||||
|
] = 2
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = model(**inputs)
|
_ = model(**inputs)
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
@@ -73,7 +74,7 @@ class VipLlavaVisionText2TextModelTester:
|
|||||||
"initializer_range": 0.02,
|
"initializer_range": 0.02,
|
||||||
"num_labels": 3,
|
"num_labels": 3,
|
||||||
"num_choices": 4,
|
"num_choices": 4,
|
||||||
"pad_token_id": 0,
|
"pad_token_id": 1,
|
||||||
},
|
},
|
||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config={
|
vision_config={
|
||||||
@@ -99,7 +100,7 @@ class VipLlavaVisionText2TextModelTester:
|
|||||||
self.vision_feature_layers = vision_feature_layers
|
self.vision_feature_layers = vision_feature_layers
|
||||||
self.text_config = text_config
|
self.text_config = text_config
|
||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
self.seq_length = seq_length
|
self.pad_token_id = text_config["pad_token_id"]
|
||||||
|
|
||||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
self.vocab_size = text_config["vocab_size"]
|
self.vocab_size = text_config["vocab_size"]
|
||||||
@@ -111,6 +112,8 @@ class VipLlavaVisionText2TextModelTester:
|
|||||||
self.num_channels = 3
|
self.num_channels = 3
|
||||||
self.image_size = 336
|
self.image_size = 336
|
||||||
self.encoder_seq_length = 231
|
self.encoder_seq_length = 231
|
||||||
|
self.num_image_tokens = 224
|
||||||
|
self.seq_length = seq_length + self.num_image_tokens
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return VipLlavaConfig(
|
return VipLlavaConfig(
|
||||||
@@ -120,6 +123,7 @@ class VipLlavaVisionText2TextModelTester:
|
|||||||
image_token_index=self.image_token_index,
|
image_token_index=self.image_token_index,
|
||||||
projector_hidden_act=self.projector_hidden_act,
|
projector_hidden_act=self.projector_hidden_act,
|
||||||
vision_feature_layers=self.vision_feature_layers,
|
vision_feature_layers=self.vision_feature_layers,
|
||||||
|
image_seq_length=self.num_image_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -140,8 +144,9 @@ class VipLlavaVisionText2TextModelTester:
|
|||||||
config, pixel_values = config_and_inputs
|
config, pixel_values = config_and_inputs
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
# we are giving 3 images let's make sure we pass in 3 image tokens
|
|
||||||
input_ids[:, 1] = config.image_token_index
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||||
|
input_ids[:, : self.num_image_tokens] = config.image_token_index
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
@@ -152,12 +157,13 @@ class VipLlavaVisionText2TextModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
# Copied from transformers.tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest with Llava->VipLlava
|
# Copied from transformers.tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest with Llava->VipLlava
|
||||||
class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
|
class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Model tester for `VipLlavaForConditionalGeneration`.
|
Model tester for `VipLlavaForConditionalGeneration`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
|
|||||||
@@ -497,19 +497,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
||||||
|
|
||||||
def _get_input_ids_and_config(self, batch_size=3):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
input_ids = inputs_dict[self.input_name]
|
|
||||||
|
|
||||||
# cut to half length & take max batch_size=batch_size
|
|
||||||
input_ids = input_ids[:batch_size, :, :]
|
|
||||||
|
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
|
||||||
config.pad_token_id = config.eos_token_id
|
|
||||||
|
|
||||||
return config, input_ids, None
|
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -4744,7 +4744,7 @@ class ModelTesterMixin:
|
|||||||
output_logits=True,
|
output_logits=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
|
self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
# For now, Let's focus only on GPU for `torch.compile`
|
# For now, Let's focus only on GPU for `torch.compile`
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user