diff --git a/examples/pytorch/text-generation/run_generation.py b/examples/pytorch/text-generation/run_generation.py
index 9b4b09fc96..e0dda0ec0c 100755
--- a/examples/pytorch/text-generation/run_generation.py
+++ b/examples/pytorch/text-generation/run_generation.py
@@ -20,6 +20,7 @@
 
 import argparse
 import logging
+from typing import Tuple
 
 import numpy as np
 import torch
@@ -27,6 +28,7 @@ import torch
 from transformers import (
     CTRLLMHeadModel,
     CTRLTokenizer,
+    GenerationMixin,
     GPT2LMHeadModel,
     GPT2Tokenizer,
     OpenAIGPTLMHeadModel,
@@ -38,6 +40,7 @@ from transformers import (
     XLNetLMHeadModel,
     XLNetTokenizer,
 )
+from transformers.modeling_outputs import CausalLMOutputWithPast
 
 
 logging.basicConfig(
@@ -151,6 +154,153 @@ def adjust_length_to_model(length, max_sequence_length):
     return length
 
 
+def sparse_model_config(model_config):
+    """
+    Read the layer count, attention-head count and per-head embedding size from a model config.
+
+    Config classes name these fields differently, so each value is looked up under several aliases.
+
+    Returns:
+        A `(num_layer, num_head, num_embedding_size_per_head)` tuple.
+
+    Raises:
+        ValueError: if a value cannot be found under any alias, or if the number of heads is 0.
+    """
+    embedding_size = None
+    if hasattr(model_config, "hidden_size"):
+        embedding_size = model_config.hidden_size
+    elif hasattr(model_config, "n_embed"):
+        embedding_size = model_config.n_embed
+    elif hasattr(model_config, "n_embd"):
+        embedding_size = model_config.n_embd
+
+    num_head = None
+    if hasattr(model_config, "num_attention_heads"):
+        num_head = model_config.num_attention_heads
+    elif hasattr(model_config, "n_head"):
+        num_head = model_config.n_head
+
+    if embedding_size is None or num_head is None or num_head == 0:
+        raise ValueError("Check the model config")
+
+    # Not every config exposes `n_layer`; fall back to the generic `num_hidden_layers` name.
+    if hasattr(model_config, "n_layer"):
+        num_layer = model_config.n_layer
+    elif hasattr(model_config, "num_hidden_layers"):
+        num_layer = model_config.num_hidden_layers
+    else:
+        raise ValueError("Number of hidden layers couldn't be determined from the model config")
+
+    # Floor division keeps the result an exact int instead of going through a float.
+    num_embedding_size_per_head = embedding_size // num_head
+
+    return num_layer, num_head, num_embedding_size_per_head
+
+
+def prepare_jit_inputs(inputs, model, tokenizer):
+    """
+    Build the dummy example inputs used to trace `model` with `torch.jit.trace`.
+
+    Args:
+        inputs: list of prompt strings; they are tokenized and padded to a common length.
+        model: model to trace; its config, dtype and device determine the dummy tensors.
+        tokenizer: tokenizer providing `batch_encode_plus`.
+
+    Returns:
+        `(input_ids, past_key_values, attention_mask)` if `model.config.use_cache` is set, otherwise
+        `(input_ids, attention_mask)`. All tensors are on `model.device`.
+    """
+    num_batch = len(inputs)
+    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
+    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
+    if model.config.model_type == "bloom":
+        # Bloom-specific layout: batch and heads share the leading dimension; the two tensors are transposed.
+        key_shape = (num_attention_heads * num_batch, num_embedding_size_per_head, 1)
+        value_shape = (num_attention_heads * num_batch, 1, num_embedding_size_per_head)
+    else:
+        key_shape = value_shape = (num_batch, num_attention_heads, 1, num_embedding_size_per_head)
+
+    # Use the dtype of the model's weights: `model.config.torch_dtype` is not guaranteed to be set.
+    past_key_values = tuple(
+        (
+            torch.zeros(key_shape, dtype=model.dtype, device=model.device),
+            torch.zeros(value_shape, dtype=model.dtype, device=model.device),
+        )
+        for _ in range(num_block_layers)
+    )
+
+    # Prepend one zero column to the mask, presumably to cover the single dummy position in `past_key_values`.
+    attention_mask = dummy_input["attention_mask"]
+    attention_mask = torch.cat(
+        [torch.zeros(attention_mask.shape[0], 1, dtype=attention_mask.dtype), attention_mask], -1
+    )
+    dummy_input["attention_mask"] = attention_mask
+
+    input_ids = dummy_input["input_ids"].to(model.device)
+    attention_mask = attention_mask.to(model.device)
+    if model.config.use_cache:
+        return input_ids, past_key_values, attention_mask
+    return input_ids, attention_mask
+
+
+class _ModelFallbackWrapper(GenerationMixin):
+    """
+    Pair a traced model with the eager model it was traced from so that `generate()` can drive both.
+
+    Calls without `past_key_values` run the eager model, all other calls run the traced one. Attributes
+    that the wrapper does not define are looked up on the eager model.
+    """
+
+    __slots__ = ("_optimized", "_default")
+
+    def __init__(self, optimized, default):
+        self._optimized = optimized
+        self._default = default
+
+    def __call__(self, *args, **kwargs):
+        # `.get`, not indexing: a call that omits `past_key_values` runs the eager model instead of raising.
+        if kwargs.get("past_key_values") is None:
+            return self._default(*args, **kwargs)
+        kwargs.pop("position_ids", None)
+        # The traced graph takes positional tensors only: drop unset arguments and boolean flags. The
+        # remaining values are passed in keyword order, which has to match the order used for tracing.
+        trace_graph_inputs = tuple(v for v in kwargs.values() if v is not None and not isinstance(v, bool))
+        outputs = self._optimized(*trace_graph_inputs)
+        # The traced model returns a plain tuple; rebuild the output object from its first two items.
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=outputs[0],
+            past_key_values=outputs[1],
+            hidden_states=None,
+            attentions=None,
+        )
+
+    def __getattr__(self, item):
+        # Only called when normal lookup fails. If `_default` itself is missing (e.g. on an instance
+        # created without `__init__`), delegating would recurse forever, so raise instead.
+        if item in self.__slots__:
+            raise AttributeError(item)
+        return getattr(self._default, item)
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
+    ):
+        """Delegate input preparation to the eager model."""
+        return self._default.prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
+        )
+
+    def _reorder_cache(
+        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
+        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+        """
+        return self._default._reorder_cache(past_key_values, beam_idx)
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -196,6 +346,9 @@ def main():
         action="store_true",
         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
     )
+    parser.add_argument(
+        "--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference"
+    )
     args = parser.parse_args()
 
     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -213,6 +366,8 @@ def main():
         raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
 
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
     model = model_class.from_pretrained(args.model_name_or_path)
     model.to(args.device)
 
@@ -248,6 +403,20 @@ def main():
     else:
         input_ids = encoded_prompt
 
+    if args.jit:
+        jit_input_texts = ["jit"]
+        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        # Trace with tuple outputs; `_ModelFallbackWrapper` reads them by index and rebuilds the output object.
+        model.config.return_dict = False
+        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
+        traced_model = torch.jit.freeze(traced_model.eval())
+        # Two calls up front, presumably warm-up so that later calls hit the optimised graph.
+        traced_model(*jit_inputs)
+        traced_model(*jit_inputs)
+
+        model = _ModelFallbackWrapper(traced_model, model)
+
     output_sequences = model.generate(
         input_ids=input_ids,
         max_length=args.length + len(encoded_prompt[0]),