Add CSM model (#36719)
* draft structure * depth decoder with forward pre hook * full model forward draft * draft update * depth decoder update * ConversationalSpeechModelForCausalLM updates * add generate * max length criteria small fix * update * updates * generation update * update in loss compute * conversion script * update for correct input embeddings * handle interleaved rope * update * update * update * support compile * update training * add doc * update doc * correct inits * ConversationalSpeechModel -> Csm * conf update * name update * tests CsmForCausalLMTest * convert use cached_file * conf + modeling updates * generate utils handle third dim shape * integration test * modeling + conf updates * common test handle more than 2 dims * add nested audio list utils * processing handle nested audio list * csm processing draft * mimi util * init updates * modular update * convert modular * processing update * csm tests update * generate tests handle third dim * generate utils handle third dim * propagate _get_initial_cache_position update * tied_weight_keys update + convert correctly * fix inputs_embeds * revert audio nested list * batch inference update + return audio * audio_utils update * processor update * some more integration tests * remove old test * processing output labels * improve * fix * update rope values with equivalent ones * conversion update * update tests * handle depth decoder generation config * remove default eos_token_id * make style * revert modeling_mimi * add default generation_config * remove sdpa since handled by default * make * fix conflict * fix conflicts * correct naming * correct imports * make * causal -> conditional naming * causal -> conditional naming * auto update * make * make * add doc * test update * fix weight init * audio tokens offsets as buffer * 4d mask in conditional class * make * doc update * fix causal mask * fix causal mask * doc update * doc update * add processor doc * update doc * fix 4d causal mask * update make_list_of_audio * 
do not default to mutable * remove duplicates * remove useless reset_parameters * use GradientCheckpointingLayer * use can_return_tuple * formatting * prepend placeholder in _sample * torch compile fix * some more fixes * convert modular * fix * default max_length in convert * handle depth decoder generation config correctly * clearer formulation * handle output_loading_info * handle softmax warning * add doc * propagate _get_initial_cache_position changes * generation in its own module * add processor tests * fix compile with cuda graphs * fix compile with cuda graphs * add csm.md * include CSM loss * doc nit * doc nit * doc nit * Update docs/source/en/model_doc/csm.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add save_audio to processor * Update src/transformers/models/csm/modular_csm.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * doc update * simplify audio_codes_mask computation * doc update * simplify loss computation * fix static cache test * fix * remove comment * simplify encoded length computation * use hf-internal-testing * doc update * cast to float before numpy * nit * mem efficient codebook head * nit * cat input values with cutoffs --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -501,9 +501,9 @@ class GenerationTesterMixin:
|
||||
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
@@ -525,13 +525,13 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -565,10 +565,10 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
|
||||
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||
@@ -582,9 +582,9 @@ class GenerationTesterMixin:
|
||||
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_sample_generate_dict_output(self):
|
||||
@@ -607,13 +607,13 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -632,9 +632,9 @@ class GenerationTesterMixin:
|
||||
output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
@@ -657,13 +657,13 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -706,10 +706,10 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
|
||||
self._check_generate_outputs(
|
||||
@@ -759,9 +759,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
@@ -786,13 +786,13 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -840,9 +840,9 @@ class GenerationTesterMixin:
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
# check `group_beam_search` for higher than 1 `num_return_sequences`
|
||||
num_return_sequences = 2
|
||||
@@ -853,9 +853,9 @@ class GenerationTesterMixin:
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
@@ -878,13 +878,13 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -923,9 +923,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
@@ -947,9 +947,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
@@ -987,13 +987,13 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -1031,9 +1031,9 @@ class GenerationTesterMixin:
|
||||
use_cache=True, # Enable cache
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
@@ -1067,10 +1067,10 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
|
||||
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||
@@ -1499,7 +1499,7 @@ class GenerationTesterMixin:
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
model_kwargs["position_ids"] = position_ids
|
||||
if "cache_position" in signature:
|
||||
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
|
||||
cache_position = torch.arange(input_ids.shape[1], device=torch_device)
|
||||
model_kwargs["cache_position"] = cache_position
|
||||
return model_kwargs
|
||||
|
||||
@@ -1525,10 +1525,12 @@ class GenerationTesterMixin:
|
||||
pad_token_id = (
|
||||
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
|
||||
)
|
||||
pad_size = (input_ids.shape[0], 32)
|
||||
pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
|
||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
|
||||
padded_attention_mask = torch.cat(
|
||||
(torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
|
||||
)
|
||||
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
|
||||
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
@@ -1587,7 +1589,7 @@ class GenerationTesterMixin:
|
||||
else text_config.num_attention_heads
|
||||
)
|
||||
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
|
||||
batch_size, seq_length = inputs["decoder_input_ids"].shape
|
||||
batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
|
||||
# The sequence length for the encoder K V depends on the model. Since it is not manipulated in
|
||||
# autoregressive generation, we're keeping the test general and not checking the 3rd dim
|
||||
default_cross_attention_shape = (
|
||||
@@ -1606,7 +1608,7 @@ class GenerationTesterMixin:
|
||||
for _ in range(num_decoder_layers)
|
||||
]
|
||||
else:
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
batch_size, seq_length = inputs["input_ids"].shape[:2]
|
||||
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
|
||||
all_cache_shapes = [
|
||||
[default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
|
||||
@@ -1727,7 +1729,7 @@ class GenerationTesterMixin:
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
}
|
||||
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
|
||||
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
|
||||
# The output of the two calls should be the same.
|
||||
@@ -2262,11 +2264,11 @@ class GenerationTesterMixin:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
@@ -2408,7 +2410,7 @@ class GenerationTesterMixin:
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
|
||||
generated_length = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
|
||||
output.sequences.shape[1] - 1 if config.is_encoder_decoder else output.sequences.shape[1] - prompt_length
|
||||
)
|
||||
decoder_past_key_values = getattr(output, "past_key_values", None)
|
||||
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
|
||||
@@ -2441,7 +2443,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
attentions=output.decoder_attentions,
|
||||
prompt_length=1, # the BOS token
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
)
|
||||
@@ -2450,7 +2452,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
attentions=output.attentions,
|
||||
prompt_length=prompt_length,
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
)
|
||||
@@ -2469,7 +2471,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
hidden_states=output.decoder_hidden_states,
|
||||
prompt_length=1, # the BOS token
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
@@ -2478,7 +2480,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
hidden_states=output.hidden_states,
|
||||
prompt_length=prompt_length,
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
@@ -2506,7 +2508,7 @@ class GenerationTesterMixin:
|
||||
)
|
||||
if has_standard_cache:
|
||||
if use_cache:
|
||||
cache_length = output.sequences.shape[-1] - 1
|
||||
cache_length = output.sequences.shape[1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
batch_size=internal_batch_size,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
|
||||
Reference in New Issue
Block a user