From 279ce5b705a0b8689f2a8e5d5258dbb5421c9e6c Mon Sep 17 00:00:00 2001 From: Jay Zhang <36183870+fatcat-z@users.noreply.github.com> Date: Thu, 7 Oct 2021 18:07:02 +0800 Subject: [PATCH] Add an example of exporting BartModel + BeamSearch to ONNX module. (#13765) * Add all example files. * Reformat files by black. * Style. * Remove unused imports. Co-authored-by: Morgan Funtowicz --- .../translation/bart_onnx/generation_onnx.py | 817 ++++++++++++++++++ .../translation/bart_onnx/reduce_onnx_size.py | 113 +++ .../onnx/pytorch/translation/requirements.txt | 1 + .../pytorch/translation/run_onnx_exporter.py | 216 +++++ 4 files changed, 1147 insertions(+) create mode 100644 examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py create mode 100644 examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py create mode 100644 examples/onnx/pytorch/translation/requirements.txt create mode 100644 examples/onnx/pytorch/translation/run_onnx_exporter.py diff --git a/examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py b/examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py new file mode 100644 index 0000000000..ac2f6a9434 --- /dev/null +++ b/examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py @@ -0,0 +1,817 @@ +import copy +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from transformers import BartConfig +from transformers.generation_utils import GenerationMixin + + +def flatten_list(past): + values = [] + if past is not None: + for i, p in enumerate(past): + for j, q in enumerate(p): + values.append(q) + + return values + + +def list_to_tuple(past): + results = () + temp_result = () + count_n = len(past) // 4 + for idx in range(count_n): + real_idx = idx * 4 + temp_result = tuple(past[real_idx : real_idx + 4]) + results += ((temp_result),) + + return results + + +class EncoderForONNX(torch.nn.Module): + def __init__(self, encoder): + super().__init__() + self.encoder = encoder + + def forward(self, input_ids, attention_mask): + return self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=False, + ) + + +class DecoderForONNX(torch.nn.Module): + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + + def forward(self, input_ids, encoder_state, attention_mask, past=None): + all_results = None + if past is not None: + all_results = list_to_tuple(past) + input_ids = input_ids[:, -1:] + + last_hidden_state, past_key_values = self.decoder( + input_ids=input_ids, + encoder_hidden_states=encoder_state, + encoder_attention_mask=attention_mask, + past_key_values=all_results, + return_dict=False, + ) + + past_values = [] + for past in past_key_values: + past_values = past_values + list(past) + return last_hidden_state, past_values + + +def create_traced_encoder(encoder, input_ids, attention_mask): + encoder_c = copy.deepcopy(encoder) + encoder_for_onnx = EncoderForONNX(encoder_c) + + # return torch.jit.trace(encoder, (input_ids, attention_mask)) + return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask)) + + +def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None): + decoder_c = copy.deepcopy(decoder) + decoder_for_onnx = DecoderForONNX(decoder_c) + past_values = flatten_list(past) + + # Do this twice so we got 2 different decoders for further work. + if past_values is None or len(past_values) == 0: + return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask)) + else: + return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values)) + + +class BartConfigTS(BartConfig, torch.nn.Module): + def init_module(self): + torch.nn.Module.__init__(self) + + +class MinLengthLogitsProcessorTS(torch.nn.Module): + r""" + :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (:obj:`int`): + The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, min_length: int, eos_token_id: int): + super().__init__() + + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def forward(self, input_ids, scores) -> torch.Tensor: + cur_len = input_ids.shape[-1] + if cur_len < self.min_length: + scores[:, self.eos_token_id] = -float("inf") + return scores + + +class BARTGenerator(torch.nn.Module, GenerationMixin): + def __init__(self, model): + super().__init__() + self.config = BartConfigTS(model.config) + self.config.init_module() + self.config.force_bos_token_to_be_generated = False + self._trace_modules(model) + self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id) + self.final_logits_weight = model.model.shared.weight + self.final_logits_bias = model.final_logits_bias + self.decoder_layers = model.config.decoder_layers + + def _trace_modules(self, model): + # Be aware of the last one 2 should be kept. + input_ids = torch.tensor( + [ + [ + 19, + 669, + 18, + 420, + 8, + 664, + 57, + 42, + 8, + 664, + 21, + 3028, + 195, + 4445, + 331, + 1293, + 34, + 21, + 10, + 6174, + 1100, + 6, + 69, + 104, + 42, + 32, + 2621, + 1638, + 144, + 4, + 6174, + 558, + 108, + 4419, + 1091, + 28, + 4, + 1668, + 9, + 1509, + 1621, + 279, + 35, + 867, + 2734, + 85, + 11, + 2216, + 2734, + 85, + 203, + 2244, + 7, + 6, + 15, + 8102, + 7, + 57, + 8629, + 5, + 2, + ] + ], + device=model.device, + dtype=torch.long, + ) + attention_mask = torch.tensor( + [ + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ] + ], + device=model.device, + dtype=torch.bool, + ) + self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask) + encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True) + decoder = model.model.decoder + decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None) + self.decoder_no_past = create_traced_decoder( + model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask + ) + self.decoder_with_past = create_traced_decoder( + model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1] + ) + + def _encoder_forward(self, input_ids, attention_mask): + return self.encoder(input_ids, attention_mask)[0] + + @staticmethod + def _init_sequence_length_for_generation( + input_ids: torch.LongTensor, max_length: int + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + unfinished_sequences = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + 1 + sequence_lengths = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + max_length + + cur_len = input_ids.shape[-1] + return sequence_lengths, unfinished_sequences, cur_len + + def _decoder_forward(self, input_ids, encoder_output, attention_mask, past: List[torch.Tensor]): + # Update here to use different decoder for different values of past. + if past is None or len(past) == 0: + decoder_output, past = self.decoder_no_past( + input_ids=input_ids, encoder_state=encoder_output, attention_mask=attention_mask + ) + else: + decoder_output, past = self.decoder_with_past( + input_ids=input_ids, encoder_state=encoder_output, attention_mask=attention_mask, past=past + ) + + lm_logits = F.linear(decoder_output, self.final_logits_weight, bias=self.final_logits_bias) + + return lm_logits, past + + def greedy_search( + self, input_ids, encoder_output, attention_mask, max_length, pad_token_id: int, eos_token_id: int + ): + # init sequence length tensors + sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation( + input_ids, max_length + ) + + past: List[torch.Tensor] = [] + while cur_len < max_length: + + logits, past = self._decoder_forward(input_ids, encoder_output, attention_mask, past) + next_token_logits = logits[:, -1, :] + + # pre-process distribution + scores = self.logits_processor(input_ids, next_token_logits) + + # argmax + next_tokens = torch.argmax(scores, dim=-1) + + # add code that transfomers next_tokens to tokens_to_add + if eos_token_id is not None: + assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined." + next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences) + + # add token and increase length by one + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + # update sequence length + if eos_token_id is not None: + sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation( + sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id + ) + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sequences.max() == 0: + break + + # increase cur_len + cur_len = cur_len + 1 + + return input_ids + + def _prepare_decoder_input_ids_for_generation( + self, + input_ids: torch.LongTensor, + decoder_start_token_id, + bos_token_id: Optional[int] = None, + ) -> torch.LongTensor: + + decoder_input_ids = ( + torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device) + * decoder_start_token_id + ) + return decoder_input_ids + + def forward(self, input_ids, attention_mask, max_length, decoder_start_token_id): + pad_token_id = self.config.pad_token_id + bos_token_id = self.config.bos_token_id + eos_token_id = self.config.eos_token_id + + # special case if pad_token_id is not defined + if pad_token_id is None and eos_token_id is not None: + # Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation. + pad_token_id = eos_token_id + + encoder_output = self._encoder_forward(input_ids, attention_mask) + + input_ids = self._prepare_decoder_input_ids_for_generation( + input_ids, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + ) + + return self.greedy_search( + input_ids, + encoder_output, + attention_mask, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + + +# TorchScript compatible BeamSearchScorer +class BeamSearchScorerTS(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_length: int = 200 + self.num_beams: int = 3 + self.batch_size: int = 1 + self.length_penalty: float = 1.0 + self.do_early_stopping: bool = True + self.num_beam_hyps_to_keep: int = 1 + self.num_beam_groups: int = 1 + self.group_size: int = self.num_beams // self.num_beam_groups + self._done = torch.zeros(self.batch_size, dtype=torch.bool) + self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long) + self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9 + self._beam_hyps_max_length: int = self.max_length - 1 + self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible + self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible + + def is_done(self) -> torch.Tensor: + return self._done.all() + + def init( + self, + batch_size: int, + max_length: int, + num_beams: int, + device: torch.device, + length_penalty: float = 1.0, + do_early_stopping: bool = False, + num_beam_hyps_to_keep: int = 1, + num_beam_groups: int = 1, + ): + self.max_length = max_length + self.num_beams = num_beams + self.batch_size = batch_size + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + self.num_beam_groups = num_beam_groups + self.group_size = self.num_beams // self.num_beam_groups + + # NOTE: TorchScript does not support List of Modules + # Rewritten BeamHypotheses with tensors and list of tensors. + self._done = torch.zeros(batch_size, dtype=torch.bool, device=device) + self._beam_hyps_count = torch.zeros(batch_size, dtype=torch.long, device=device) + self._beam_hyps_worst_scores = torch.zeros(batch_size, device=device) + 1e9 + self._beam_hyps = [] + self._beam_scores = [] + + self._beam_hyps_max_length = max_length - 1 # ignoring bos_token + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." + ) + + if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): + raise ValueError( + f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` " + f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." + ) + + def hypo_len(self, hypo_idx: int): + """ + Number of hypotheses in the list. + """ + return self._beam_hyps_count[hypo_idx] + + def hypo_add(self, hyp: torch.Tensor, sum_logprobs: float, hypo_idx: int): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) + hyps_count = self.hypo_len(hypo_idx) + if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]: + # NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx. + beam_idx = ( + torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long) + ) + # beam_idx = torch.sum(_beam_hyps_count[:hypo_idx]) + self._beam_scores.insert(beam_idx, torch.tensor([score])) + self._beam_hyps.insert(beam_idx, hyp) + if hyps_count + 1 > self.num_beams: + sorted_next_scores, sorted_indices = torch.topk( + torch.cat(self._beam_scores)[beam_idx : beam_idx + hyps_count + 1], hyps_count + 1, largest=False + ) + del self._beam_hyps[int((sorted_indices[0] + beam_idx))] + del self._beam_scores[int((sorted_indices[0] + beam_idx))] + self._beam_hyps_worst_scores[hypo_idx] = sorted_next_scores[1] + else: + self._beam_hyps_worst_scores[hypo_idx] = min(score, self._beam_hyps_worst_scores[hypo_idx]) + self._beam_hyps_count[hypo_idx] = hyps_count + 1 + + def hypo_is_done(self, hypo_idx: int, best_sum_logprobs: float, cur_len: int) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. + """ + if self.hypo_len(hypo_idx) < self.num_beams: + return False + elif self.do_early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self._beam_hyps_worst_scores[hypo_idx].item() >= cur_score + return ret + + def process( + self, + input_ids: torch.Tensor, + next_scores: torch.Tensor, + next_tokens: torch.Tensor, + next_indices: torch.Tensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps_count) + assert batch_size == (input_ids.shape[0] // self.group_size) + + device = input_ids.device + next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + + for batch_idx in range(batch_size): + if self._done[batch_idx]: + assert ( + self.hypo_len(batch_idx) >= self.num_beams + ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.group_size + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size + if is_beam_token_worse_than_top_num_beams: + continue + self.hypo_add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + batch_idx, + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.group_size: + break + + if beam_idx < self.group_size: + raise ValueError( + f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or self.hypo_is_done( + batch_idx, + next_scores[batch_idx].max().item(), + cur_len, + ) + + return next_beam_scores.view(-1), next_beam_tokens.view(-1), next_beam_indices.view(-1) + + def finalize( + self, + input_ids: torch.Tensor, + final_beam_scores: torch.Tensor, + final_beam_tokens: torch.Tensor, + final_beam_indices: torch.Tensor, + pad_token_id: int, + eos_token_id: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = len(self._beam_hyps_count) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx in range(batch_size): + if self._done[batch_idx]: + continue + + # all open beam hypotheses are added to the beam hypothesis + # beam hypothesis class automatically keeps the best beams + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + self.hypo_add(final_tokens, final_score, batch_idx) + + # select the best hypotheses + # NOTE: new is not scriptable + sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long) + best = [] + best_scores = torch.zeros( + batch_size * self.num_beam_hyps_to_keep, device=input_ids.device, dtype=torch.float32 + ) + # retrieve best hypotheses + for i in range(batch_size): + # NOTE: lambda is not scriptable + batch_hypo_start = torch.sum(self._beam_hyps_count[:i]) if i > 0 else torch.tensor(0, dtype=torch.long) + batch_hypo_end = torch.sum(self._beam_hyps_count[: i + 1]) + beam_scores = torch.cat(self._beam_scores)[batch_hypo_start:batch_hypo_end] + sorted_next_scores, sorted_indices = torch.topk(beam_scores, len(beam_scores), largest=True) + for j in range(self.num_beam_hyps_to_keep): + best_score = beam_scores[sorted_indices[j]] + best_hyp = self._beam_hyps[batch_hypo_start + sorted_indices[j]] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + # append to lists + best.append(best_hyp) + best_scores[i * self.num_beam_hyps_to_keep + j] = best_score + + # prepare for adding eos + sent_max_len = min(sent_lengths.max() + 1, self.max_length) + decoded = torch.zeros(batch_size * self.num_beam_hyps_to_keep, sent_max_len, dtype=torch.long) + # shorter batches are padded if needed + if sent_lengths.min() != sent_lengths.max(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < self.max_length: + decoded[i, sent_lengths[i]] = eos_token_id + + return decoded, best_scores + + +class BARTBeamSearchGenerator(BARTGenerator): + def __init__(self, model): + super().__init__(model) + self.beam_scorer = BeamSearchScorerTS() + self.device = model.device + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + last_hidden_state: torch.Tensor, + expand_size: int = 1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + attention_mask = attention_mask.index_select(0, expanded_return_idx) + + last_hidden_state = last_hidden_state.index_select(0, expanded_return_idx.to(last_hidden_state.device)) + return input_ids, attention_mask, last_hidden_state + + def adjust_logits_during_generation(self, logits, cur_len: int, max_length: int): + if cur_len == 1 and self.config.force_bos_token_to_be_generated: + logits = self._force_token_id_to_be_generated(logits, self.config.bos_token_id) + elif cur_len == max_length - 1 and self.config.eos_token_id is not None: + logits = self._force_token_id_to_be_generated(logits, self.config.eos_token_id) + return logits + + @staticmethod + def _force_token_id_to_be_generated(scores, token_id: int): + """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" + mask = torch.full_like(scores, 1, dtype=torch.bool) + mask[:, token_id] = False + return scores.masked_fill(mask, -float("inf")) + + def _reorder_cache(self, past: List[torch.Tensor], beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + reordered_decoder_past = [] + for state in past: + reordered_decoder_past.append(state.index_select(0, beam_idx)) + return reordered_decoder_past + + def beam_search( + self, input_ids, encoder_output, attention_mask, num_beams, max_length, pad_token_id: int, eos_token_id: int + ): + + batch_size = self.beam_scorer.batch_size + + num_beams = self.beam_scorer.num_beams + batch_beam_size, cur_len = input_ids.shape + + assert ( + num_beams * batch_size == batch_beam_size + ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + next_tokens = torch.zeros((batch_size, num_beams), dtype=torch.long, device=input_ids.device) + next_indices = torch.zeros((batch_size, num_beams), dtype=torch.long, device=input_ids.device) + + past: List[torch.Tensor] = [] + while cur_len < max_length: + logits, past = self._decoder_forward(input_ids, encoder_output, attention_mask, past) + next_token_logits = logits[:, -1, :] + + # adjust tokens for Bart, *e.g.* + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) + + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) + + # pre-process distribution + next_token_scores = self.logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + + beam_scores, beam_next_tokens, beam_idx = self.beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + cur_len = cur_len + 1 + + if len(past) > 0: + past = self._reorder_cache(past, beam_idx) + + if self.beam_scorer.is_done(): + break + + sequences, sequence_scores = self.beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + + return sequences + + def forward(self, input_ids, attention_mask, num_beams, max_length, decoder_start_token_id): + pad_token_id = self.config.pad_token_id + bos_token_id = self.config.bos_token_id + eos_token_id = self.config.eos_token_id + + # special case if pad_token_id is not defined + if pad_token_id is None and eos_token_id is not None: + # logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + pad_token_id = eos_token_id + + encoder_output = self._encoder_forward(input_ids, attention_mask) + + input_ids = self._prepare_decoder_input_ids_for_generation( + input_ids, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + ) + + # from generation_utils.py + batch_size = input_ids.shape[0] + + length_penalty = self.config.length_penalty + num_return_sequences = self.config.num_return_sequences + early_stopping = True + + self.beam_scorer.init( + batch_size=batch_size, + max_length=max_length, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + + input_ids, attention_mask, encoder_output = self._expand_inputs_for_generation( + input_ids, + attention_mask, + encoder_output, + expand_size=num_beams, + ) + + return self.beam_search( + input_ids=input_ids, + encoder_output=encoder_output, + attention_mask=attention_mask, + num_beams=num_beams, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) diff --git a/examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py b/examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py new file mode 100644 index 0000000000..d16c1e4c41 --- /dev/null +++ b/examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py @@ -0,0 +1,113 @@ +import os + +import numpy + +import onnx + + +def is_equal_tensor_proto(a, b): + name_a = a.name + name_b = b.name + + a.name = "" + b.name = "" + + res = a == b + + a.name = name_a + b.name = name_b + + return res + + +def node_replace_input_with(node_proto, name, new_name): + for i, input_name in enumerate(node_proto.input): + if input_name == name: + node_proto.input.insert(i, new_name) + node_proto.input.pop(i + 1) + + if node_proto.op_type == "If": + graph_replace_input_with(node_proto.attribute[0].g, name, new_name) + graph_replace_input_with(node_proto.attribute[1].g, name, new_name) + if node_proto.op_type == "Loop": + graph_replace_input_with(node_proto.attribute[0].g, name, new_name) + + +def graph_replace_input_with(graph_proto, name, new_name): + for n in graph_proto.node: + node_replace_input_with(n, name, new_name) + + +def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace): + inits_with_data = [i for i in model.graph.initializer] + inits = [i for i in model_without_ext.graph.initializer] + for i, ref_i in ind_to_replace: + assert inits_with_data[i].name == inits[i].name + assert inits_with_data[ref_i].name == inits[ref_i].name + assert i > ref_i + + name_i = inits[i].name + name_ref = inits[ref_i].name + + model_without_ext.graph.initializer.remove(inits[i]) + + # for n in model.graph.node: + graph_replace_input_with(model_without_ext.graph, name_i, name_ref) + + +def remove_dup_initializers(onnx_file_path): + model_file_folder = os.path.dirname(onnx_file_path) + model_file_name = os.path.basename(onnx_file_path) + + model = onnx.load(os.path.join(model_file_folder, model_file_name)) + + inits = [i for i in model.graph.initializer] + + dup_set = set() + dup_map = {} + ind_to_replace = [] + + total_reduced_size = 0 + + for i in range(len(inits)): + if i in dup_set: + continue + + for j in range(i + 1, len(inits)): + if j in dup_set: + continue + if is_equal_tensor_proto(inits[i], inits[j]): + dup_set.add(i) + dup_set.add(j) + + dtype = inits[j].data_type + mem_size = numpy.prod(inits[j].dims) + if dtype == 1: + mem_size *= 4 + elif dtype == 6: + mem_size *= 4 + elif dtype == 7 or dtype == 11: + mem_size *= 8 + else: + print("unexpected data type: ", dtype) + total_reduced_size += mem_size + + name_i = inits[i].name + name_j = inits[j].name + + if name_i in dup_map: + dup_map[name_i].append(name_j) + else: + dup_map[name_i] = [name_j] + ind_to_replace.append((j, i)) + + print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB") + + ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0]) + remove_dup_initializers_from_model(model, model, ind_to_replace) + + optimized_model_file_name = "optimized_" + model_file_name + new_model = os.path.join(model_file_folder, optimized_model_file_name) + onnx.save(model, new_model) + + return new_model diff --git a/examples/onnx/pytorch/translation/requirements.txt b/examples/onnx/pytorch/translation/requirements.txt new file mode 100644 index 0000000000..5714ddd648 --- /dev/null +++ b/examples/onnx/pytorch/translation/requirements.txt @@ -0,0 +1 @@ +torch >= 1.8 \ No newline at end of file diff --git a/examples/onnx/pytorch/translation/run_onnx_exporter.py b/examples/onnx/pytorch/translation/run_onnx_exporter.py new file mode 100644 index 0000000000..1355903df6 --- /dev/null +++ b/examples/onnx/pytorch/translation/run_onnx_exporter.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +""" +import argparse +import logging +import os +import sys +from datetime import datetime + +import numpy as np +import torch + +import onnxruntime +import transformers +from bart_onnx.generation_onnx import BARTBeamSearchGenerator +from bart_onnx.reduce_onnx_size import remove_dup_initializers +from transformers import BartForConditionalGeneration, BartTokenizer + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) + +logger = logging.getLogger(__name__) + +model_dict = {"facebook/bart-base": BartForConditionalGeneration} +tokenizer_dict = {"facebook/bart-base": BartTokenizer} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + ) + parser.add_argument( + "--max_length", + type=int, + default=5, + help=("The maximum total input sequence length after tokenization."), + ) + parser.add_argument( + "--num_beams", + type=int, + default=None, + help="Number of beams to use for evaluation. This argument will be " + "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="Device where the model will be run", + ) + parser.add_argument("--output_file_path", type=str, default=None, help="Where to store the final ONNX file.") + + args = parser.parse_args() + + return args + + +def load_model_tokenizer(model_name, device="cpu"): + huggingface_model = model_dict[model_name].from_pretrained(model_name).to(device) + tokenizer = tokenizer_dict[model_name].from_pretrained(model_name) + + if model_name in ["facebook/bart-base"]: + huggingface_model.config.no_repeat_ngram_size = 0 + huggingface_model.config.forced_bos_token_id = None + huggingface_model.config.min_length = 0 + + return huggingface_model, tokenizer + + +def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_length): + model.eval() + + ort_sess = None + onnx_bart = torch.jit.script(BARTBeamSearchGenerator(model)) + + with torch.no_grad(): + ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." + inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device) + + # Test export here. + summary_ids = model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + num_beams=num_beams, + max_length=max_length, + early_stopping=True, + decoder_start_token_id=model.config.decoder_start_token_id, + ) + + if not ort_sess: + torch.onnx.export( + onnx_bart, + ( + inputs["input_ids"], + inputs["attention_mask"], + num_beams, + max_length, + model.config.decoder_start_token_id, + ), + onnx_file_path, + opset_version=14, + input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"], + output_names=["output_ids"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "seq"}, + "output_ids": {0: "batch", 1: "seq_out"}, + }, + verbose=False, + strip_doc_string=False, + example_outputs=summary_ids, + ) + + new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path)) + + ort_sess = onnxruntime.InferenceSession(new_onnx_file_path) + ort_out = ort_sess.run( + None, + { + "input_ids": inputs["input_ids"].cpu().numpy(), + "attention_mask": inputs["attention_mask"].cpu().numpy(), + "num_beams": np.array(num_beams), + "max_length": np.array(max_length), + "decoder_start_token_id": np.array(model.config.decoder_start_token_id), + }, + ) + + np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3) + + print("========= Pass - Results are matched! =========") + + +def main(): + args = parse_args() + local_device = None + local_max_length = 5 + local_num_beams = 4 + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + logger.setLevel(logging.ERROR) + transformers.utils.logging.set_verbosity_error() + + if args.model_name_or_path: + model, tokenizer = load_model_tokenizer(args.model_name_or_path, local_device) + else: + raise ValueError("Make sure that model name has been passed") + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + if args.device: + if args.device == "cuda" and not torch.cuda.is_available(): + raise ValueError("CUDA is not available in this server.") + + local_device = torch.device(args.device) + else: + local_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + model.to(local_device) + + if args.max_length: + local_max_length = args.max_length + + if args.num_beams: + local_num_beams = args.num_beams + + if args.output_file_path: + output_name = args.output_file_path + else: + output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond) + + export_and_validate_model(model, tokenizer, output_name, local_num_beams, local_max_length) + + logger.info("***** Running export *****") + + +if __name__ == "__main__": + main()