Add an example of exporting BartModel + BeamSearch to ONNX module. (#13765)
* Add all example files. * Reformat files by black. * Style. * Remove unused imports. Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
817
examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py
Normal file
817
examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py
Normal file
@@ -0,0 +1,817 @@
|
||||
import copy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import BartConfig
|
||||
from transformers.generation_utils import GenerationMixin
|
||||
|
||||
|
||||
def flatten_list(past):
|
||||
values = []
|
||||
if past is not None:
|
||||
for i, p in enumerate(past):
|
||||
for j, q in enumerate(p):
|
||||
values.append(q)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def list_to_tuple(past):
|
||||
results = ()
|
||||
temp_result = ()
|
||||
count_n = len(past) // 4
|
||||
for idx in range(count_n):
|
||||
real_idx = idx * 4
|
||||
temp_result = tuple(past[real_idx : real_idx + 4])
|
||||
results += ((temp_result),)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class EncoderForONNX(torch.nn.Module):
|
||||
def __init__(self, encoder):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
|
||||
class DecoderForONNX(torch.nn.Module):
|
||||
def __init__(self, decoder):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(self, input_ids, encoder_state, attention_mask, past=None):
|
||||
all_results = None
|
||||
if past is not None:
|
||||
all_results = list_to_tuple(past)
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
last_hidden_state, past_key_values = self.decoder(
|
||||
input_ids=input_ids,
|
||||
encoder_hidden_states=encoder_state,
|
||||
encoder_attention_mask=attention_mask,
|
||||
past_key_values=all_results,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
past_values = []
|
||||
for past in past_key_values:
|
||||
past_values = past_values + list(past)
|
||||
return last_hidden_state, past_values
|
||||
|
||||
|
||||
def create_traced_encoder(encoder, input_ids, attention_mask):
|
||||
encoder_c = copy.deepcopy(encoder)
|
||||
encoder_for_onnx = EncoderForONNX(encoder_c)
|
||||
|
||||
# return torch.jit.trace(encoder, (input_ids, attention_mask))
|
||||
return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))
|
||||
|
||||
|
||||
def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
|
||||
decoder_c = copy.deepcopy(decoder)
|
||||
decoder_for_onnx = DecoderForONNX(decoder_c)
|
||||
past_values = flatten_list(past)
|
||||
|
||||
# Do this twice so we got 2 different decoders for further work.
|
||||
if past_values is None or len(past_values) == 0:
|
||||
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
|
||||
else:
|
||||
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
|
||||
|
||||
|
||||
class BartConfigTS(BartConfig, torch.nn.Module):
|
||||
def init_module(self):
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
|
||||
class MinLengthLogitsProcessorTS(torch.nn.Module):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
|
||||
|
||||
Args:
|
||||
min_length (:obj:`int`):
|
||||
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
|
||||
eos_token_id (:obj:`int`):
|
||||
The id of the `end-of-sequence` token.
|
||||
"""
|
||||
|
||||
def __init__(self, min_length: int, eos_token_id: int):
|
||||
super().__init__()
|
||||
|
||||
if not isinstance(min_length, int) or min_length < 0:
|
||||
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
||||
|
||||
if not isinstance(eos_token_id, int) or eos_token_id < 0:
|
||||
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
|
||||
|
||||
self.min_length = min_length
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def forward(self, input_ids, scores) -> torch.Tensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
if cur_len < self.min_length:
|
||||
scores[:, self.eos_token_id] = -float("inf")
|
||||
return scores
|
||||
|
||||
|
||||
class BARTGenerator(torch.nn.Module, GenerationMixin):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.config = BartConfigTS(model.config)
|
||||
self.config.init_module()
|
||||
self.config.force_bos_token_to_be_generated = False
|
||||
self._trace_modules(model)
|
||||
self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
|
||||
self.final_logits_weight = model.model.shared.weight
|
||||
self.final_logits_bias = model.final_logits_bias
|
||||
self.decoder_layers = model.config.decoder_layers
|
||||
|
||||
def _trace_modules(self, model):
|
||||
# Be aware of the last one 2 should be kept.
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
19,
|
||||
669,
|
||||
18,
|
||||
420,
|
||||
8,
|
||||
664,
|
||||
57,
|
||||
42,
|
||||
8,
|
||||
664,
|
||||
21,
|
||||
3028,
|
||||
195,
|
||||
4445,
|
||||
331,
|
||||
1293,
|
||||
34,
|
||||
21,
|
||||
10,
|
||||
6174,
|
||||
1100,
|
||||
6,
|
||||
69,
|
||||
104,
|
||||
42,
|
||||
32,
|
||||
2621,
|
||||
1638,
|
||||
144,
|
||||
4,
|
||||
6174,
|
||||
558,
|
||||
108,
|
||||
4419,
|
||||
1091,
|
||||
28,
|
||||
4,
|
||||
1668,
|
||||
9,
|
||||
1509,
|
||||
1621,
|
||||
279,
|
||||
35,
|
||||
867,
|
||||
2734,
|
||||
85,
|
||||
11,
|
||||
2216,
|
||||
2734,
|
||||
85,
|
||||
203,
|
||||
2244,
|
||||
7,
|
||||
6,
|
||||
15,
|
||||
8102,
|
||||
7,
|
||||
57,
|
||||
8629,
|
||||
5,
|
||||
2,
|
||||
]
|
||||
],
|
||||
device=model.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
attention_mask = torch.tensor(
|
||||
[
|
||||
[
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
]
|
||||
],
|
||||
device=model.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
|
||||
encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||
decoder = model.model.decoder
|
||||
decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
|
||||
self.decoder_no_past = create_traced_decoder(
|
||||
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
|
||||
)
|
||||
self.decoder_with_past = create_traced_decoder(
|
||||
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
|
||||
)
|
||||
|
||||
def _encoder_forward(self, input_ids, attention_mask):
|
||||
return self.encoder(input_ids, attention_mask)[0]
|
||||
|
||||
@staticmethod
|
||||
def _init_sequence_length_for_generation(
|
||||
input_ids: torch.LongTensor, max_length: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
unfinished_sequences = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + 1
|
||||
sequence_lengths = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + max_length
|
||||
|
||||
cur_len = input_ids.shape[-1]
|
||||
return sequence_lengths, unfinished_sequences, cur_len
|
||||
|
||||
def _decoder_forward(self, input_ids, encoder_output, attention_mask, past: List[torch.Tensor]):
|
||||
# Update here to use different decoder for different values of past.
|
||||
if past is None or len(past) == 0:
|
||||
decoder_output, past = self.decoder_no_past(
|
||||
input_ids=input_ids, encoder_state=encoder_output, attention_mask=attention_mask
|
||||
)
|
||||
else:
|
||||
decoder_output, past = self.decoder_with_past(
|
||||
input_ids=input_ids, encoder_state=encoder_output, attention_mask=attention_mask, past=past
|
||||
)
|
||||
|
||||
lm_logits = F.linear(decoder_output, self.final_logits_weight, bias=self.final_logits_bias)
|
||||
|
||||
return lm_logits, past
|
||||
|
||||
def greedy_search(
|
||||
self, input_ids, encoder_output, attention_mask, max_length, pad_token_id: int, eos_token_id: int
|
||||
):
|
||||
# init sequence length tensors
|
||||
sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
|
||||
input_ids, max_length
|
||||
)
|
||||
|
||||
past: List[torch.Tensor] = []
|
||||
while cur_len < max_length:
|
||||
|
||||
logits, past = self._decoder_forward(input_ids, encoder_output, attention_mask, past)
|
||||
next_token_logits = logits[:, -1, :]
|
||||
|
||||
# pre-process distribution
|
||||
scores = self.logits_processor(input_ids, next_token_logits)
|
||||
|
||||
# argmax
|
||||
next_tokens = torch.argmax(scores, dim=-1)
|
||||
|
||||
# add code that transfomers next_tokens to tokens_to_add
|
||||
if eos_token_id is not None:
|
||||
assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
|
||||
next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)
|
||||
|
||||
# add token and increase length by one
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
|
||||
# update sequence length
|
||||
if eos_token_id is not None:
|
||||
sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
|
||||
sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
|
||||
)
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||
if unfinished_sequences.max() == 0:
|
||||
break
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
return input_ids
|
||||
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
decoder_start_token_id,
|
||||
bos_token_id: Optional[int] = None,
|
||||
) -> torch.LongTensor:
|
||||
|
||||
decoder_input_ids = (
|
||||
torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
|
||||
* decoder_start_token_id
|
||||
)
|
||||
return decoder_input_ids
|
||||
|
||||
def forward(self, input_ids, attention_mask, max_length, decoder_start_token_id):
|
||||
pad_token_id = self.config.pad_token_id
|
||||
bos_token_id = self.config.bos_token_id
|
||||
eos_token_id = self.config.eos_token_id
|
||||
|
||||
# special case if pad_token_id is not defined
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
# Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
encoder_output = self._encoder_forward(input_ids, attention_mask)
|
||||
|
||||
input_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
)
|
||||
|
||||
return self.greedy_search(
|
||||
input_ids,
|
||||
encoder_output,
|
||||
attention_mask,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
|
||||
|
||||
# TorchScript compatible BeamSearchScorer
|
||||
class BeamSearchScorerTS(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max_length: int = 200
|
||||
self.num_beams: int = 3
|
||||
self.batch_size: int = 1
|
||||
self.length_penalty: float = 1.0
|
||||
self.do_early_stopping: bool = True
|
||||
self.num_beam_hyps_to_keep: int = 1
|
||||
self.num_beam_groups: int = 1
|
||||
self.group_size: int = self.num_beams // self.num_beam_groups
|
||||
self._done = torch.zeros(self.batch_size, dtype=torch.bool)
|
||||
self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
|
||||
self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
|
||||
self._beam_hyps_max_length: int = self.max_length - 1
|
||||
self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
|
||||
self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
|
||||
|
||||
def is_done(self) -> torch.Tensor:
|
||||
return self._done.all()
|
||||
|
||||
def init(
|
||||
self,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
num_beams: int,
|
||||
device: torch.device,
|
||||
length_penalty: float = 1.0,
|
||||
do_early_stopping: bool = False,
|
||||
num_beam_hyps_to_keep: int = 1,
|
||||
num_beam_groups: int = 1,
|
||||
):
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
self.batch_size = batch_size
|
||||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
self.num_beam_groups = num_beam_groups
|
||||
self.group_size = self.num_beams // self.num_beam_groups
|
||||
|
||||
# NOTE: TorchScript does not support List of Modules
|
||||
# Rewritten BeamHypotheses with tensors and list of tensors.
|
||||
self._done = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
||||
self._beam_hyps_count = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
self._beam_hyps_worst_scores = torch.zeros(batch_size, device=device) + 1e9
|
||||
self._beam_hyps = []
|
||||
self._beam_scores = []
|
||||
|
||||
self._beam_hyps_max_length = max_length - 1 # ignoring bos_token
|
||||
|
||||
if not isinstance(num_beams, int) or num_beams <= 1:
|
||||
raise ValueError(
|
||||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
|
||||
)
|
||||
|
||||
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
||||
raise ValueError(
|
||||
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
|
||||
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
)
|
||||
|
||||
def hypo_len(self, hypo_idx: int):
|
||||
"""
|
||||
Number of hypotheses in the list.
|
||||
"""
|
||||
return self._beam_hyps_count[hypo_idx]
|
||||
|
||||
def hypo_add(self, hyp: torch.Tensor, sum_logprobs: float, hypo_idx: int):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||
hyps_count = self.hypo_len(hypo_idx)
|
||||
if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
|
||||
# NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
|
||||
beam_idx = (
|
||||
torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
|
||||
)
|
||||
# beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
|
||||
self._beam_scores.insert(beam_idx, torch.tensor([score]))
|
||||
self._beam_hyps.insert(beam_idx, hyp)
|
||||
if hyps_count + 1 > self.num_beams:
|
||||
sorted_next_scores, sorted_indices = torch.topk(
|
||||
torch.cat(self._beam_scores)[beam_idx : beam_idx + hyps_count + 1], hyps_count + 1, largest=False
|
||||
)
|
||||
del self._beam_hyps[int((sorted_indices[0] + beam_idx))]
|
||||
del self._beam_scores[int((sorted_indices[0] + beam_idx))]
|
||||
self._beam_hyps_worst_scores[hypo_idx] = sorted_next_scores[1]
|
||||
else:
|
||||
self._beam_hyps_worst_scores[hypo_idx] = min(score, self._beam_hyps_worst_scores[hypo_idx])
|
||||
self._beam_hyps_count[hypo_idx] = hyps_count + 1
|
||||
|
||||
def hypo_is_done(self, hypo_idx: int, best_sum_logprobs: float, cur_len: int) -> bool:
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
||||
one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
if self.hypo_len(hypo_idx) < self.num_beams:
|
||||
return False
|
||||
elif self.do_early_stopping:
|
||||
return True
|
||||
else:
|
||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||
ret = self._beam_hyps_worst_scores[hypo_idx].item() >= cur_score
|
||||
return ret
|
||||
|
||||
def process(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
next_scores: torch.Tensor,
|
||||
next_tokens: torch.Tensor,
|
||||
next_indices: torch.Tensor,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
cur_len = input_ids.shape[-1]
|
||||
batch_size = len(self._beam_hyps_count)
|
||||
assert batch_size == (input_ids.shape[0] // self.group_size)
|
||||
|
||||
device = input_ids.device
|
||||
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
if self._done[batch_idx]:
|
||||
assert (
|
||||
self.hypo_len(batch_idx) >= self.num_beams
|
||||
), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
|
||||
assert (
|
||||
eos_token_id is not None and pad_token_id is not None
|
||||
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
||||
# pad the batch
|
||||
next_beam_scores[batch_idx, :] = 0
|
||||
next_beam_tokens[batch_idx, :] = pad_token_id
|
||||
next_beam_indices[batch_idx, :] = 0
|
||||
continue
|
||||
|
||||
# next tokens for this sentence
|
||||
beam_idx = 0
|
||||
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
||||
):
|
||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (next_token == eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
self.hypo_add(
|
||||
input_ids[batch_beam_idx].clone(),
|
||||
next_score.item(),
|
||||
batch_idx,
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
next_beam_scores[batch_idx, beam_idx] = next_score
|
||||
next_beam_tokens[batch_idx, beam_idx] = next_token
|
||||
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
||||
beam_idx += 1
|
||||
|
||||
# once the beam for next step is full, don't add more tokens to it.
|
||||
if beam_idx == self.group_size:
|
||||
break
|
||||
|
||||
if beam_idx < self.group_size:
|
||||
raise ValueError(
|
||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
)
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
self._done[batch_idx] = self._done[batch_idx] or self.hypo_is_done(
|
||||
batch_idx,
|
||||
next_scores[batch_idx].max().item(),
|
||||
cur_len,
|
||||
)
|
||||
|
||||
return next_beam_scores.view(-1), next_beam_tokens.view(-1), next_beam_indices.view(-1)
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
final_beam_scores: torch.Tensor,
|
||||
final_beam_tokens: torch.Tensor,
|
||||
final_beam_indices: torch.Tensor,
|
||||
pad_token_id: int,
|
||||
eos_token_id: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = len(self._beam_hyps_count)
|
||||
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
for batch_idx in range(batch_size):
|
||||
if self._done[batch_idx]:
|
||||
continue
|
||||
|
||||
# all open beam hypotheses are added to the beam hypothesis
|
||||
# beam hypothesis class automatically keeps the best beams
|
||||
for beam_id in range(self.num_beams):
|
||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
self.hypo_add(final_tokens, final_score, batch_idx)
|
||||
|
||||
# select the best hypotheses
|
||||
# NOTE: new is not scriptable
|
||||
sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
|
||||
best = []
|
||||
best_scores = torch.zeros(
|
||||
batch_size * self.num_beam_hyps_to_keep, device=input_ids.device, dtype=torch.float32
|
||||
)
|
||||
# retrieve best hypotheses
|
||||
for i in range(batch_size):
|
||||
# NOTE: lambda is not scriptable
|
||||
batch_hypo_start = torch.sum(self._beam_hyps_count[:i]) if i > 0 else torch.tensor(0, dtype=torch.long)
|
||||
batch_hypo_end = torch.sum(self._beam_hyps_count[: i + 1])
|
||||
beam_scores = torch.cat(self._beam_scores)[batch_hypo_start:batch_hypo_end]
|
||||
sorted_next_scores, sorted_indices = torch.topk(beam_scores, len(beam_scores), largest=True)
|
||||
for j in range(self.num_beam_hyps_to_keep):
|
||||
best_score = beam_scores[sorted_indices[j]]
|
||||
best_hyp = self._beam_hyps[batch_hypo_start + sorted_indices[j]]
|
||||
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
||||
# append to lists
|
||||
best.append(best_hyp)
|
||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||
|
||||
# prepare for adding eos
|
||||
sent_max_len = min(sent_lengths.max() + 1, self.max_length)
|
||||
decoded = torch.zeros(batch_size * self.num_beam_hyps_to_keep, sent_max_len, dtype=torch.long)
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min() != sent_lengths.max():
|
||||
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < self.max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
|
||||
return decoded, best_scores
|
||||
|
||||
|
||||
class BARTBeamSearchGenerator(BARTGenerator):
|
||||
def __init__(self, model):
|
||||
super().__init__(model)
|
||||
self.beam_scorer = BeamSearchScorerTS()
|
||||
self.device = model.device
|
||||
|
||||
@staticmethod
|
||||
def _expand_inputs_for_generation(
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
last_hidden_state: torch.Tensor,
|
||||
expand_size: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
expanded_return_idx = (
|
||||
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
|
||||
)
|
||||
input_ids = input_ids.index_select(0, expanded_return_idx)
|
||||
|
||||
attention_mask = attention_mask.index_select(0, expanded_return_idx)
|
||||
|
||||
last_hidden_state = last_hidden_state.index_select(0, expanded_return_idx.to(last_hidden_state.device))
|
||||
return input_ids, attention_mask, last_hidden_state
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len: int, max_length: int):
|
||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
||||
logits = self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
||||
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
logits = self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id: int):
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
mask = torch.full_like(scores, 1, dtype=torch.bool)
|
||||
mask[:, token_id] = False
|
||||
return scores.masked_fill(mask, -float("inf"))
|
||||
|
||||
def _reorder_cache(self, past: List[torch.Tensor], beam_idx):
|
||||
# if decoder past is not included in output
|
||||
# speedy decoding is disabled and no need to reorder
|
||||
reordered_decoder_past = []
|
||||
for state in past:
|
||||
reordered_decoder_past.append(state.index_select(0, beam_idx))
|
||||
return reordered_decoder_past
|
||||
|
||||
def beam_search(
|
||||
self, input_ids, encoder_output, attention_mask, num_beams, max_length, pad_token_id: int, eos_token_id: int
|
||||
):
|
||||
|
||||
batch_size = self.beam_scorer.batch_size
|
||||
|
||||
num_beams = self.beam_scorer.num_beams
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
assert (
|
||||
num_beams * batch_size == batch_beam_size
|
||||
), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
next_tokens = torch.zeros((batch_size, num_beams), dtype=torch.long, device=input_ids.device)
|
||||
next_indices = torch.zeros((batch_size, num_beams), dtype=torch.long, device=input_ids.device)
|
||||
|
||||
past: List[torch.Tensor] = []
|
||||
while cur_len < max_length:
|
||||
logits, past = self._decoder_forward(input_ids, encoder_output, attention_mask, past)
|
||||
next_token_logits = logits[:, -1, :]
|
||||
|
||||
# adjust tokens for Bart, *e.g.*
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(input_ids, next_token_scores)
|
||||
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
||||
|
||||
# reshape for beam search
|
||||
vocab_size = next_token_scores.shape[-1]
|
||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||
|
||||
next_token_scores, next_tokens = torch.topk(
|
||||
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = next_tokens // vocab_size
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
beam_scores, beam_next_tokens, beam_idx = self.beam_scorer.process(
|
||||
input_ids,
|
||||
next_token_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
|
||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if len(past) > 0:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
if self.beam_scorer.is_done():
|
||||
break
|
||||
|
||||
sequences, sequence_scores = self.beam_scorer.finalize(
|
||||
input_ids,
|
||||
beam_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
|
||||
return sequences
|
||||
|
||||
def forward(self, input_ids, attention_mask, num_beams, max_length, decoder_start_token_id):
|
||||
pad_token_id = self.config.pad_token_id
|
||||
bos_token_id = self.config.bos_token_id
|
||||
eos_token_id = self.config.eos_token_id
|
||||
|
||||
# special case if pad_token_id is not defined
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
# logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
encoder_output = self._encoder_forward(input_ids, attention_mask)
|
||||
|
||||
input_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
)
|
||||
|
||||
# from generation_utils.py
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
length_penalty = self.config.length_penalty
|
||||
num_return_sequences = self.config.num_return_sequences
|
||||
early_stopping = True
|
||||
|
||||
self.beam_scorer.init(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
do_early_stopping=early_stopping,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
)
|
||||
|
||||
input_ids, attention_mask, encoder_output = self._expand_inputs_for_generation(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_output,
|
||||
expand_size=num_beams,
|
||||
)
|
||||
|
||||
return self.beam_search(
|
||||
input_ids=input_ids,
|
||||
encoder_output=encoder_output,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=num_beams,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
113
examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py
Normal file
113
examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import os
|
||||
|
||||
import numpy
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def is_equal_tensor_proto(a, b):
|
||||
name_a = a.name
|
||||
name_b = b.name
|
||||
|
||||
a.name = ""
|
||||
b.name = ""
|
||||
|
||||
res = a == b
|
||||
|
||||
a.name = name_a
|
||||
b.name = name_b
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def node_replace_input_with(node_proto, name, new_name):
|
||||
for i, input_name in enumerate(node_proto.input):
|
||||
if input_name == name:
|
||||
node_proto.input.insert(i, new_name)
|
||||
node_proto.input.pop(i + 1)
|
||||
|
||||
if node_proto.op_type == "If":
|
||||
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||
graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
|
||||
if node_proto.op_type == "Loop":
|
||||
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||
|
||||
|
||||
def graph_replace_input_with(graph_proto, name, new_name):
|
||||
for n in graph_proto.node:
|
||||
node_replace_input_with(n, name, new_name)
|
||||
|
||||
|
||||
def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
|
||||
inits_with_data = [i for i in model.graph.initializer]
|
||||
inits = [i for i in model_without_ext.graph.initializer]
|
||||
for i, ref_i in ind_to_replace:
|
||||
assert inits_with_data[i].name == inits[i].name
|
||||
assert inits_with_data[ref_i].name == inits[ref_i].name
|
||||
assert i > ref_i
|
||||
|
||||
name_i = inits[i].name
|
||||
name_ref = inits[ref_i].name
|
||||
|
||||
model_without_ext.graph.initializer.remove(inits[i])
|
||||
|
||||
# for n in model.graph.node:
|
||||
graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
|
||||
|
||||
|
||||
def remove_dup_initializers(onnx_file_path):
|
||||
model_file_folder = os.path.dirname(onnx_file_path)
|
||||
model_file_name = os.path.basename(onnx_file_path)
|
||||
|
||||
model = onnx.load(os.path.join(model_file_folder, model_file_name))
|
||||
|
||||
inits = [i for i in model.graph.initializer]
|
||||
|
||||
dup_set = set()
|
||||
dup_map = {}
|
||||
ind_to_replace = []
|
||||
|
||||
total_reduced_size = 0
|
||||
|
||||
for i in range(len(inits)):
|
||||
if i in dup_set:
|
||||
continue
|
||||
|
||||
for j in range(i + 1, len(inits)):
|
||||
if j in dup_set:
|
||||
continue
|
||||
if is_equal_tensor_proto(inits[i], inits[j]):
|
||||
dup_set.add(i)
|
||||
dup_set.add(j)
|
||||
|
||||
dtype = inits[j].data_type
|
||||
mem_size = numpy.prod(inits[j].dims)
|
||||
if dtype == 1:
|
||||
mem_size *= 4
|
||||
elif dtype == 6:
|
||||
mem_size *= 4
|
||||
elif dtype == 7 or dtype == 11:
|
||||
mem_size *= 8
|
||||
else:
|
||||
print("unexpected data type: ", dtype)
|
||||
total_reduced_size += mem_size
|
||||
|
||||
name_i = inits[i].name
|
||||
name_j = inits[j].name
|
||||
|
||||
if name_i in dup_map:
|
||||
dup_map[name_i].append(name_j)
|
||||
else:
|
||||
dup_map[name_i] = [name_j]
|
||||
ind_to_replace.append((j, i))
|
||||
|
||||
print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")
|
||||
|
||||
ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0])
|
||||
remove_dup_initializers_from_model(model, model, ind_to_replace)
|
||||
|
||||
optimized_model_file_name = "optimized_" + model_file_name
|
||||
new_model = os.path.join(model_file_folder, optimized_model_file_name)
|
||||
onnx.save(model, new_model)
|
||||
|
||||
return new_model
|
||||
1
examples/onnx/pytorch/translation/requirements.txt
Normal file
1
examples/onnx/pytorch/translation/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
torch >= 1.8
|
||||
216
examples/onnx/pytorch/translation/run_onnx_exporter.py
Normal file
216
examples/onnx/pytorch/translation/run_onnx_exporter.py
Normal file
@@ -0,0 +1,216 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import onnxruntime
|
||||
import transformers
|
||||
from bart_onnx.generation_onnx import BARTBeamSearchGenerator
|
||||
from bart_onnx.reduce_onnx_size import remove_dup_initializers
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||
stream=sys.stdout,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
model_dict = {"facebook/bart-base": BartForConditionalGeneration}
|
||||
tokenizer_dict = {"facebook/bart-base": BartTokenizer}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=5,
|
||||
help=("The maximum total input sequence length after tokenization."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of beams to use for evaluation. This argument will be "
|
||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Device where the model will be run",
|
||||
)
|
||||
parser.add_argument("--output_file_path", type=str, default=None, help="Where to store the final ONNX file.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def load_model_tokenizer(model_name, device="cpu"):
|
||||
huggingface_model = model_dict[model_name].from_pretrained(model_name).to(device)
|
||||
tokenizer = tokenizer_dict[model_name].from_pretrained(model_name)
|
||||
|
||||
if model_name in ["facebook/bart-base"]:
|
||||
huggingface_model.config.no_repeat_ngram_size = 0
|
||||
huggingface_model.config.forced_bos_token_id = None
|
||||
huggingface_model.config.min_length = 0
|
||||
|
||||
return huggingface_model, tokenizer
|
||||
|
||||
|
||||
def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_length):
|
||||
model.eval()
|
||||
|
||||
ort_sess = None
|
||||
onnx_bart = torch.jit.script(BARTBeamSearchGenerator(model))
|
||||
|
||||
with torch.no_grad():
|
||||
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device)
|
||||
|
||||
# Test export here.
|
||||
summary_ids = model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
num_beams=num_beams,
|
||||
max_length=max_length,
|
||||
early_stopping=True,
|
||||
decoder_start_token_id=model.config.decoder_start_token_id,
|
||||
)
|
||||
|
||||
if not ort_sess:
|
||||
torch.onnx.export(
|
||||
onnx_bart,
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
num_beams,
|
||||
max_length,
|
||||
model.config.decoder_start_token_id,
|
||||
),
|
||||
onnx_file_path,
|
||||
opset_version=14,
|
||||
input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
|
||||
output_names=["output_ids"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch", 1: "seq"},
|
||||
"output_ids": {0: "batch", 1: "seq_out"},
|
||||
},
|
||||
verbose=False,
|
||||
strip_doc_string=False,
|
||||
example_outputs=summary_ids,
|
||||
)
|
||||
|
||||
new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
|
||||
|
||||
ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
|
||||
ort_out = ort_sess.run(
|
||||
None,
|
||||
{
|
||||
"input_ids": inputs["input_ids"].cpu().numpy(),
|
||||
"attention_mask": inputs["attention_mask"].cpu().numpy(),
|
||||
"num_beams": np.array(num_beams),
|
||||
"max_length": np.array(max_length),
|
||||
"decoder_start_token_id": np.array(model.config.decoder_start_token_id),
|
||||
},
|
||||
)
|
||||
|
||||
np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
|
||||
|
||||
print("========= Pass - Results are matched! =========")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
local_device = None
|
||||
local_max_length = 5
|
||||
local_num_beams = 4
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
logger.setLevel(logging.ERROR)
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
if args.model_name_or_path:
|
||||
model, tokenizer = load_model_tokenizer(args.model_name_or_path, local_device)
|
||||
else:
|
||||
raise ValueError("Make sure that model name has been passed")
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
if args.device:
|
||||
if args.device == "cuda" and not torch.cuda.is_available():
|
||||
raise ValueError("CUDA is not available in this server.")
|
||||
|
||||
local_device = torch.device(args.device)
|
||||
else:
|
||||
local_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
model.to(local_device)
|
||||
|
||||
if args.max_length:
|
||||
local_max_length = args.max_length
|
||||
|
||||
if args.num_beams:
|
||||
local_num_beams = args.num_beams
|
||||
|
||||
if args.output_file_path:
|
||||
output_name = args.output_file_path
|
||||
else:
|
||||
output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond)
|
||||
|
||||
export_and_validate_model(model, tokenizer, output_name, local_num_beams, local_max_length)
|
||||
|
||||
logger.info("***** Running export *****")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user