Enable traced model for text-generation task (#22265)
This commit is contained in:
@@ -20,6 +20,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -27,6 +28,7 @@ import torch
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
CTRLLMHeadModel,
|
CTRLLMHeadModel,
|
||||||
CTRLTokenizer,
|
CTRLTokenizer,
|
||||||
|
GenerationMixin,
|
||||||
GPT2LMHeadModel,
|
GPT2LMHeadModel,
|
||||||
GPT2Tokenizer,
|
GPT2Tokenizer,
|
||||||
OpenAIGPTLMHeadModel,
|
OpenAIGPTLMHeadModel,
|
||||||
@@ -38,6 +40,7 @@ from transformers import (
|
|||||||
XLNetLMHeadModel,
|
XLNetLMHeadModel,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -151,6 +154,131 @@ def adjust_length_to_model(length, max_sequence_length):
|
|||||||
return length
|
return length
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_model_config(model_config):
|
||||||
|
embedding_size = None
|
||||||
|
if hasattr(model_config, "hidden_size"):
|
||||||
|
embedding_size = model_config.hidden_size
|
||||||
|
elif hasattr(model_config, "n_embed"):
|
||||||
|
embedding_size = model_config.n_embed
|
||||||
|
elif hasattr(model_config, "n_embd"):
|
||||||
|
embedding_size = model_config.n_embd
|
||||||
|
|
||||||
|
num_head = None
|
||||||
|
if hasattr(model_config, "num_attention_heads"):
|
||||||
|
num_head = model_config.num_attention_heads
|
||||||
|
elif hasattr(model_config, "n_head"):
|
||||||
|
num_head = model_config.n_head
|
||||||
|
|
||||||
|
if embedding_size is None or num_head is None or num_head == 0:
|
||||||
|
raise ValueError("Check the model config")
|
||||||
|
|
||||||
|
num_embedding_size_per_head = int(embedding_size / num_head)
|
||||||
|
num_layer = model_config.n_layer
|
||||||
|
|
||||||
|
return num_layer, num_head, num_embedding_size_per_head
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_jit_inputs(inputs, model, tokenizer):
|
||||||
|
num_batch = len(inputs)
|
||||||
|
dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
|
||||||
|
num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
|
||||||
|
if model.config.model_type == "bloom":
|
||||||
|
past_key_values = tuple(
|
||||||
|
(
|
||||||
|
torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
|
||||||
|
.to(model.config.torch_dtype)
|
||||||
|
.to(model.device),
|
||||||
|
torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
|
||||||
|
.to(model.config.torch_dtype)
|
||||||
|
.to(model.device),
|
||||||
|
)
|
||||||
|
for _ in range(num_block_layers)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
past_key_values = tuple(
|
||||||
|
(
|
||||||
|
torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
|
||||||
|
.to(model.config.torch_dtype)
|
||||||
|
.to(model.device),
|
||||||
|
torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
|
||||||
|
.to(model.config.torch_dtype)
|
||||||
|
.to(model.device),
|
||||||
|
)
|
||||||
|
for _ in range(num_block_layers)
|
||||||
|
)
|
||||||
|
|
||||||
|
dummy_input["attention_mask"] = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
|
||||||
|
dummy_input["attention_mask"],
|
||||||
|
],
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model.config.use_cache:
|
||||||
|
jit_inputs = (
|
||||||
|
dummy_input["input_ids"].to(model.device),
|
||||||
|
past_key_values,
|
||||||
|
dummy_input["attention_mask"].to(model.device),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
jit_inputs = (
|
||||||
|
dummy_input["input_ids"].to(model.device),
|
||||||
|
dummy_input["attention_mask"].to(model.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
return jit_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelFallbackWrapper(GenerationMixin):
|
||||||
|
__slots__ = ("_optimized", "_default")
|
||||||
|
|
||||||
|
def __init__(self, optimized, default):
|
||||||
|
self._optimized = optimized
|
||||||
|
self._default = default
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if kwargs["past_key_values"] is None:
|
||||||
|
return self._default(*args, **kwargs)
|
||||||
|
trace_graph_inputs = []
|
||||||
|
kwargs.pop("position_ids", None)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if v is not None and not isinstance(v, bool):
|
||||||
|
trace_graph_inputs.append(v)
|
||||||
|
trace_graph_inputs = tuple(trace_graph_inputs)
|
||||||
|
outputs = self._optimized(*trace_graph_inputs)
|
||||||
|
lm_logits = outputs[0]
|
||||||
|
past_key_values = outputs[1]
|
||||||
|
fixed_output = CausalLMOutputWithPast(
|
||||||
|
loss=None,
|
||||||
|
logits=lm_logits,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
hidden_states=None,
|
||||||
|
attentions=None,
|
||||||
|
)
|
||||||
|
return fixed_output
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
return getattr(self._default, item)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
|
||||||
|
):
|
||||||
|
return self._default.prepare_inputs_for_generation(
|
||||||
|
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _reorder_cache(
|
||||||
|
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
||||||
|
) -> Tuple[Tuple[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
|
||||||
|
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
||||||
|
beam_idx at every generation step.
|
||||||
|
"""
|
||||||
|
return self._default._reorder_cache(past_key_values, beam_idx)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -196,6 +324,9 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||||
@@ -213,6 +344,8 @@ def main():
|
|||||||
raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
|
raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
|
||||||
|
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
model = model_class.from_pretrained(args.model_name_or_path)
|
model = model_class.from_pretrained(args.model_name_or_path)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
@@ -248,6 +381,18 @@ def main():
|
|||||||
else:
|
else:
|
||||||
input_ids = encoded_prompt
|
input_ids = encoded_prompt
|
||||||
|
|
||||||
|
if args.jit:
|
||||||
|
jit_input_texts = ["jit"]
|
||||||
|
jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
|
||||||
|
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||||
|
model.config.return_dict = False
|
||||||
|
traced_model = torch.jit.trace(model, jit_inputs, strict=False)
|
||||||
|
traced_model = torch.jit.freeze(traced_model.eval())
|
||||||
|
traced_model(*jit_inputs)
|
||||||
|
traced_model(*jit_inputs)
|
||||||
|
|
||||||
|
model = _ModelFallbackWrapper(traced_model, model)
|
||||||
|
|
||||||
output_sequences = model.generate(
|
output_sequences = model.generate(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
max_length=args.length + len(encoded_prompt[0]),
|
max_length=args.length + len(encoded_prompt[0]),
|
||||||
|
|||||||
Reference in New Issue
Block a user