|
|
|
@@ -1391,43 +1391,6 @@ class GenerationMixin:
|
|
|
|
UserWarning,
|
|
|
|
UserWarning,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
|
|
|
key = "decoder_attention_mask"
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
key = "attention_mask"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if key not in model_kwargs:
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mask = model_kwargs[key]
|
|
|
|
|
|
|
|
mask_extension_length = new_mask_length - mask.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if mask_extension_length < 0:
|
|
|
|
|
|
|
|
raise ValueError("Cannot extend attention mask to a length less than it already is")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_kwargs[key] = torch.cat(
|
|
|
|
|
|
|
|
[mask, mask.new_ones((mask.shape[0], mask_extension_length))],
|
|
|
|
|
|
|
|
dim=-1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
token_type_ids = model_kwargs["token_type_ids"]
|
|
|
|
|
|
|
|
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
|
|
|
|
|
|
|
|
extension_length = new_length - token_type_ids.shape[1]
|
|
|
|
|
|
|
|
token_type_copies = final_token_type.repeat(1, extension_length)
|
|
|
|
|
|
|
|
model_kwargs["token_type_ids"] = torch.cat(
|
|
|
|
|
|
|
|
[model_kwargs["token_type_ids"], token_type_copies],
|
|
|
|
|
|
|
|
dim=-1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def generate(
|
|
|
|
def generate(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
@@ -4505,11 +4468,6 @@ class GenerationMixin:
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
|
|
|
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
|
|
|
|
|
|
|
|
|
|
|
# check if assistant model accepts encoder_outputs
|
|
|
|
|
|
|
|
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
|
|
|
|
|
|
|
|
inspect.signature(assistant_model.forward).parameters.keys()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# init values
|
|
|
|
# init values
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
|
|
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
|
|
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
|
|
|
@@ -4547,20 +4505,32 @@ class GenerationMixin:
|
|
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# prepare assistant model's keys of inputs
|
|
|
|
|
|
|
|
assistant_kwargs = copy.copy(model_kwargs)
|
|
|
|
|
|
|
|
if assistant_model.config.is_encoder_decoder:
|
|
|
|
|
|
|
|
# both are encoder-decoder
|
|
|
|
|
|
|
|
input_ids_key = "decoder_input_ids"
|
|
|
|
|
|
|
|
attention_key = "decoder_attention_mask"
|
|
|
|
|
|
|
|
assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
|
|
|
|
|
|
|
|
elif "assistant_encoder_outputs" in assistant_kwargs:
|
|
|
|
|
|
|
|
# special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
|
|
|
|
|
|
|
|
input_ids_key = "input_ids"
|
|
|
|
|
|
|
|
attention_key = "attention_mask"
|
|
|
|
|
|
|
|
assistant_kwargs["attention_mask"] = assistant_kwargs.get(
|
|
|
|
|
|
|
|
"decoder_attention_mask",
|
|
|
|
|
|
|
|
torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
# both are decoder-only
|
|
|
|
|
|
|
|
input_ids_key = "input_ids"
|
|
|
|
|
|
|
|
attention_key = "attention_mask"
|
|
|
|
|
|
|
|
|
|
|
|
# keep track of which sequences are already finished
|
|
|
|
# keep track of which sequences are already finished
|
|
|
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
|
|
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
|
|
|
|
|
|
|
|
|
|
|
# other auxiliary variables
|
|
|
|
# other auxiliary variables
|
|
|
|
max_len = stopping_criteria[0].max_length
|
|
|
|
max_len = stopping_criteria[0].max_length
|
|
|
|
assistant_kv_indexing = (
|
|
|
|
|
|
|
|
1
|
|
|
|
|
|
|
|
if "bloom" in assistant_model.__class__.__name__.lower()
|
|
|
|
|
|
|
|
or (
|
|
|
|
|
|
|
|
assistant_model.config.architectures is not None
|
|
|
|
|
|
|
|
and "bloom" in assistant_model.config.architectures[0].lower()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else 0
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this_peer_finished = False # used by synced_gpus only
|
|
|
|
this_peer_finished = False # used by synced_gpus only
|
|
|
|
while True:
|
|
|
|
while True:
|
|
|
|
@@ -4582,44 +4552,21 @@ class GenerationMixin:
|
|
|
|
# need access to the assistant cache to secure strong speedups.
|
|
|
|
# need access to the assistant cache to secure strong speedups.
|
|
|
|
candidate_input_ids = input_ids
|
|
|
|
candidate_input_ids = input_ids
|
|
|
|
for _ in range(int(num_assistant_tokens)):
|
|
|
|
for _ in range(int(num_assistant_tokens)):
|
|
|
|
# 1.1. use the assistant model to obtain the next candidate logits
|
|
|
|
# 1.1 prepare assistant model inputs
|
|
|
|
if "assistant_past_key_values" in model_kwargs:
|
|
|
|
assistant_inputs = assistant_model.prepare_inputs_for_generation(
|
|
|
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
|
|
|
|
candidate_input_ids,
|
|
|
|
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
|
|
|
**assistant_kwargs,
|
|
|
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
|
|
|
)
|
|
|
|
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
|
|
|
|
|
|
|
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
|
|
|
|
|
|
|
if assistant_model.config.is_encoder_decoder:
|
|
|
|
|
|
|
|
assistant_model_outputs = assistant_model(
|
|
|
|
|
|
|
|
decoder_input_ids=assist_inputs,
|
|
|
|
|
|
|
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
|
|
|
|
|
|
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
encoder_kwargs = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
|
|
|
# 1.2. check if the input ids length is correct
|
|
|
|
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
|
|
|
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
|
|
|
|
|
|
|
|
if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2):
|
|
|
|
|
|
|
|
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
|
|
|
|
|
|
|
|
|
|
|
|
assistant_model_outputs = assistant_model(
|
|
|
|
# 1.3. use the assistant model to obtain the next candidate logits
|
|
|
|
assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
|
|
|
|
assistant_model_outputs = assistant_model(**assistant_inputs)
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
if assistant_model.config.is_encoder_decoder:
|
|
|
|
|
|
|
|
assistant_model_outputs = assistant_model(
|
|
|
|
|
|
|
|
decoder_input_ids=candidate_input_ids,
|
|
|
|
|
|
|
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
encoder_kwargs = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
|
|
|
# 1.4. greedily select the next candidate token
|
|
|
|
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 1.2. greedily select the next candidate token
|
|
|
|
|
|
|
|
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
|
|
|
|
|
|
|
if len(logits_processor) > 0:
|
|
|
|
if len(logits_processor) > 0:
|
|
|
|
assistant_model_outputs.logits[:, -1, :] = logits_processor(
|
|
|
|
assistant_model_outputs.logits[:, -1, :] = logits_processor(
|
|
|
|
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
|
|
|
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
|
|
|
@@ -4627,7 +4574,13 @@ class GenerationMixin:
|
|
|
|
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
|
|
|
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
|
|
|
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
|
|
|
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
# 1.3. stop assistant generation on EOS
|
|
|
|
# 1.5. update assistant model inputs
|
|
|
|
|
|
|
|
if assistant_kwargs.get(attention_key, None) is not None:
|
|
|
|
|
|
|
|
mask = assistant_kwargs[attention_key]
|
|
|
|
|
|
|
|
assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
|
|
|
|
|
|
|
|
assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 1.6. stop assistant generation on EOS
|
|
|
|
if eos_token_id_tensor is not None:
|
|
|
|
if eos_token_id_tensor is not None:
|
|
|
|
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
|
|
|
|
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
|
|
|
|
last_assistant_token_is_eos = (
|
|
|
|
last_assistant_token_is_eos = (
|
|
|
|
@@ -4646,8 +4599,10 @@ class GenerationMixin:
|
|
|
|
|
|
|
|
|
|
|
|
# 2.1. Prepare the model inputs
|
|
|
|
# 2.1. Prepare the model inputs
|
|
|
|
candidate_kwargs = copy.copy(model_kwargs)
|
|
|
|
candidate_kwargs = copy.copy(model_kwargs)
|
|
|
|
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
|
|
|
|
candidate_kwargs = _prepare_attention_mask(
|
|
|
|
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
|
|
|
|
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
|
|
|
|
|
|
|
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
@@ -4699,8 +4654,8 @@ class GenerationMixin:
|
|
|
|
# 5.3. Discard past key values relative to unused assistant tokens
|
|
|
|
# 5.3. Discard past key values relative to unused assistant tokens
|
|
|
|
new_cache_size = new_cur_len - 1
|
|
|
|
new_cache_size = new_cur_len - 1
|
|
|
|
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
|
|
|
|
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
|
|
|
|
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
|
|
|
assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
|
|
|
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
|
|
|
|
assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1
|
|
|
|
) # the assistant does not have the token after the last match, hence the -1
|
|
|
|
) # the assistant does not have the token after the last match, hence the -1
|
|
|
|
|
|
|
|
|
|
|
|
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
|
|
|
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
|
|
|
@@ -4761,6 +4716,12 @@ class GenerationMixin:
|
|
|
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
|
|
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Update assistant_kwargs for the assistant's next round of generations
|
|
|
|
|
|
|
|
assistant_kwargs = _prepare_attention_mask(
|
|
|
|
|
|
|
|
assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len)
|
|
|
|
|
|
|
|
|
|
|
|
# if eos_token was found in one sentence, set sentence to finished
|
|
|
|
# if eos_token was found in one sentence, set sentence to finished
|
|
|
|
if eos_token_id_tensor is not None:
|
|
|
|
if eos_token_id_tensor is not None:
|
|
|
|
unfinished_sequences = unfinished_sequences.mul(
|
|
|
|
unfinished_sequences = unfinished_sequences.mul(
|
|
|
|
@@ -4938,3 +4899,37 @@ def _ranking_fast(
|
|
|
|
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
|
|
|
|
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
|
|
|
|
_, selected_idx = contrastive_score.max(dim=-1) # [B]
|
|
|
|
_, selected_idx = contrastive_score.max(dim=-1) # [B]
|
|
|
|
return selected_idx
|
|
|
|
return selected_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
"""Expands or crops the model's mask for decoding purposes, to the defined length"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
|
|
|
|
|
|
|
|
if mask_key not in model_kwargs:
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mask = model_kwargs[mask_key]
|
|
|
|
|
|
|
|
mask_length_diff = new_length - mask.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if mask_length_diff < 0:
|
|
|
|
|
|
|
|
model_kwargs[mask_key] = mask[:, :mask_length_diff]
|
|
|
|
|
|
|
|
elif mask_length_diff > 0:
|
|
|
|
|
|
|
|
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
"""Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
|
|
|
|
|
|
|
|
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
token_type_ids = model_kwargs["token_type_ids"]
|
|
|
|
|
|
|
|
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
|
|
|
|
|
|
|
|
type_length_diff = new_length - token_type_ids.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if type_length_diff < 0:
|
|
|
|
|
|
|
|
token_type_ids = token_type_ids[:, :type_length_diff]
|
|
|
|
|
|
|
|
elif type_length_diff > 0:
|
|
|
|
|
|
|
|
token_type_copies = final_token_type.repeat(1, type_length_diff)
|
|
|
|
|
|
|
|
model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
|
|
|
|
|
|
|
|
return model_kwargs
|
|
|
|
|