Update the example of exporting Bart + BeamSearch to ONNX module to resolve comments. (#14310)
* Update code to resolve comments left in previous PR. * Add README.md file for this example. * Update examples/onnx/pytorch/translation/README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update examples/onnx/pytorch/translation/README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update examples/onnx/pytorch/translation/README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update README.md file to resolve comments. * Add a section name. * Update examples/onnx/pytorch/translation/README.md Co-authored-by: Gary Miguel <garymm@garymm.org> * Add more comments for _convert_past_list_to_tuple(). * Change the default file name to a consistent one. * Fix a format issue. * Update examples/onnx/pytorch/translation/README.md Co-authored-by: Gary Miguel <garymm@garymm.org> * Update examples/onnx/pytorch/translation/run_onnx_exporter.py Co-authored-by: Gary Miguel <garymm@garymm.org> * Update examples/onnx/pytorch/translation/README.md Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Change the folder to summarization and address some other coments. * Update the torch version. Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Gary Miguel <garymm@garymm.org> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
43
examples/onnx/pytorch/summarization/README.md
Normal file
43
examples/onnx/pytorch/summarization/README.md
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
<!---
|
||||||
|
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Bart + Beam Search to ONNX
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.
|
||||||
|
|
||||||
|
Beam Search contains a for-loop workflow, so we need to make them TorchScript-compatible for exporting to ONNX. This example shows how to make a Bart model be TorchScript-compatible by wrapping up it into a new model. In addition, some changes were made to the `beam_search()` function to make it TorchScript-compatible.
|
||||||
|
|
||||||
|
|
||||||
|
## How to run the example
|
||||||
|
|
||||||
|
To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/huggingface/transformers
|
||||||
|
cd transformers
|
||||||
|
pip install .
|
||||||
|
```
|
||||||
|
Then cd in this example folder and run
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can run the example command below to get the example ONNX file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run_onnx_exporter.py --model_name_or_path facebook/bart-base
|
||||||
|
```
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
|
import itertools
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -8,23 +9,23 @@ from transformers import BartConfig
|
|||||||
from transformers.generation_utils import GenerationMixin
|
from transformers.generation_utils import GenerationMixin
|
||||||
|
|
||||||
|
|
||||||
def flatten_list(past):
|
def _convert_past_list_to_tuple(past_key_values):
|
||||||
values = []
|
"""
|
||||||
if past is not None:
|
In Bart model, the type of past_key_values is tuple(tuple(torch.FloatTensor)) which is not
|
||||||
for i, p in enumerate(past):
|
TorchScript-compatible. To support this, we have to convert it during the export process.
|
||||||
for j, q in enumerate(p):
|
This function will convert past values from a list to tuple(tuple(torch.FloatTensor)) for
|
||||||
values.append(q)
|
the inner decoder.
|
||||||
|
|
||||||
return values
|
According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
|
||||||
|
so we convert every 4 elements in the list as a tuple(torch.FloatTensor).
|
||||||
|
"""
|
||||||
def list_to_tuple(past):
|
count_of_each_inner_tuple = 4
|
||||||
results = ()
|
results = ()
|
||||||
temp_result = ()
|
temp_result = ()
|
||||||
count_n = len(past) // 4
|
count_n = len(past_key_values) // count_of_each_inner_tuple
|
||||||
for idx in range(count_n):
|
for idx in range(count_n):
|
||||||
real_idx = idx * 4
|
real_idx = idx * count_of_each_inner_tuple
|
||||||
temp_result = tuple(past[real_idx : real_idx + 4])
|
temp_result = tuple(past_key_values[real_idx : real_idx + count_of_each_inner_tuple])
|
||||||
results += ((temp_result),)
|
results += ((temp_result),)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -51,7 +52,7 @@ class DecoderForONNX(torch.nn.Module):
|
|||||||
def forward(self, input_ids, encoder_state, attention_mask, past=None):
|
def forward(self, input_ids, encoder_state, attention_mask, past=None):
|
||||||
all_results = None
|
all_results = None
|
||||||
if past is not None:
|
if past is not None:
|
||||||
all_results = list_to_tuple(past)
|
all_results = _convert_past_list_to_tuple(past)
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
last_hidden_state, past_key_values = self.decoder(
|
last_hidden_state, past_key_values = self.decoder(
|
||||||
@@ -68,28 +69,33 @@ class DecoderForONNX(torch.nn.Module):
|
|||||||
return last_hidden_state, past_values
|
return last_hidden_state, past_values
|
||||||
|
|
||||||
|
|
||||||
def create_traced_encoder(encoder, input_ids, attention_mask):
|
def _create_traced_encoder(encoder, input_ids, attention_mask):
|
||||||
encoder_c = copy.deepcopy(encoder)
|
encoder_c = copy.deepcopy(encoder)
|
||||||
encoder_for_onnx = EncoderForONNX(encoder_c)
|
encoder_for_onnx = EncoderForONNX(encoder_c)
|
||||||
|
|
||||||
# return torch.jit.trace(encoder, (input_ids, attention_mask))
|
|
||||||
return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))
|
return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))
|
||||||
|
|
||||||
|
|
||||||
def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
|
def _create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
|
||||||
decoder_c = copy.deepcopy(decoder)
|
decoder_c = copy.deepcopy(decoder)
|
||||||
decoder_for_onnx = DecoderForONNX(decoder_c)
|
decoder_for_onnx = DecoderForONNX(decoder_c)
|
||||||
past_values = flatten_list(past)
|
past_values = list(itertools.chain.from_iterable(past or ()))
|
||||||
|
|
||||||
# Do this twice so we got 2 different decoders for further work.
|
# Do this twice so we got 2 different decoders for further work.
|
||||||
if past_values is None or len(past_values) == 0:
|
if past_values:
|
||||||
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
|
|
||||||
else:
|
|
||||||
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
|
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
|
||||||
|
else:
|
||||||
|
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
|
||||||
|
|
||||||
|
|
||||||
class BartConfigTS(BartConfig, torch.nn.Module):
|
class BartConfigTS(BartConfig, torch.nn.Module):
|
||||||
def init_module(self):
|
"""
|
||||||
|
BartConfigTS is a TorchScript-compatible transformers.models.bart.configuration_bart.BartConfig.
|
||||||
|
TorchScript only supports sub-classes of torch.nn.Module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
BartConfig.__init__(self, config)
|
||||||
torch.nn.Module.__init__(self)
|
torch.nn.Module.__init__(self)
|
||||||
|
|
||||||
|
|
||||||
@@ -127,7 +133,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
|||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = BartConfigTS(model.config)
|
self.config = BartConfigTS(model.config)
|
||||||
self.config.init_module()
|
|
||||||
self.config.force_bos_token_to_be_generated = False
|
self.config.force_bos_token_to_be_generated = False
|
||||||
self._trace_modules(model)
|
self._trace_modules(model)
|
||||||
self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
|
self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
|
||||||
@@ -136,7 +141,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
|||||||
self.decoder_layers = model.config.decoder_layers
|
self.decoder_layers = model.config.decoder_layers
|
||||||
|
|
||||||
def _trace_modules(self, model):
|
def _trace_modules(self, model):
|
||||||
# Be aware of the last one 2 should be kept.
|
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
@@ -200,89 +204,25 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
|||||||
57,
|
57,
|
||||||
8629,
|
8629,
|
||||||
5,
|
5,
|
||||||
2,
|
model.config.eos_token_id,
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
device=model.device,
|
device=model.device,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
attention_mask = torch.tensor(
|
attention_mask = torch.tensor(
|
||||||
[
|
[[True] * input_ids.shape[-1]],
|
||||||
[
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
True,
|
|
||||||
]
|
|
||||||
],
|
|
||||||
device=model.device,
|
device=model.device,
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
)
|
)
|
||||||
self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
|
self.encoder = _create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
|
||||||
encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
|
encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||||
decoder = model.model.decoder
|
decoder = model.model.decoder
|
||||||
decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
|
decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
|
||||||
self.decoder_no_past = create_traced_decoder(
|
self.decoder_no_past = _create_traced_decoder(
|
||||||
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
|
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
|
||||||
)
|
)
|
||||||
self.decoder_with_past = create_traced_decoder(
|
self.decoder_with_past = _create_traced_decoder(
|
||||||
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
|
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -414,8 +354,8 @@ class BeamSearchScorerTS(torch.nn.Module):
|
|||||||
self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
|
self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
|
||||||
self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
|
self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
|
||||||
self._beam_hyps_max_length: int = self.max_length - 1
|
self._beam_hyps_max_length: int = self.max_length - 1
|
||||||
self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
|
self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatibility
|
||||||
self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
|
self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatibility
|
||||||
|
|
||||||
def is_done(self) -> torch.Tensor:
|
def is_done(self) -> torch.Tensor:
|
||||||
return self._done.all()
|
return self._done.all()
|
||||||
@@ -474,11 +414,11 @@ class BeamSearchScorerTS(torch.nn.Module):
|
|||||||
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||||
hyps_count = self.hypo_len(hypo_idx)
|
hyps_count = self.hypo_len(hypo_idx)
|
||||||
if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
|
if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
|
||||||
# NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
|
# NOTE: work around difference of torch.sum(empty_tensor) == 0, while error in onnx.
|
||||||
|
# Bug: https://msdata.visualstudio.com/Vienna/_workitems/edit/1486599
|
||||||
beam_idx = (
|
beam_idx = (
|
||||||
torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
|
torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
|
||||||
)
|
)
|
||||||
# beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
|
|
||||||
self._beam_scores.insert(beam_idx, torch.tensor([score]))
|
self._beam_scores.insert(beam_idx, torch.tensor([score]))
|
||||||
self._beam_hyps.insert(beam_idx, hyp)
|
self._beam_hyps.insert(beam_idx, hyp)
|
||||||
if hyps_count + 1 > self.num_beams:
|
if hyps_count + 1 > self.num_beams:
|
||||||
@@ -605,7 +545,7 @@ class BeamSearchScorerTS(torch.nn.Module):
|
|||||||
self.hypo_add(final_tokens, final_score, batch_idx)
|
self.hypo_add(final_tokens, final_score, batch_idx)
|
||||||
|
|
||||||
# select the best hypotheses
|
# select the best hypotheses
|
||||||
# NOTE: new is not scriptable
|
# NOTE: torch.Tensor.new_zeros() is not scriptable
|
||||||
sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
|
sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
|
||||||
best = []
|
best = []
|
||||||
best_scores = torch.zeros(
|
best_scores = torch.zeros(
|
||||||
@@ -782,7 +722,6 @@ class BARTBeamSearchGenerator(BARTGenerator):
|
|||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# from generation_utils.py
|
|
||||||
batch_size = input_ids.shape[0]
|
batch_size = input_ids.shape[0]
|
||||||
|
|
||||||
length_penalty = self.config.length_penalty
|
length_penalty = self.config.length_penalty
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Code to remove duplicate initializers to reduce ONNX model size.
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
@@ -5,7 +9,7 @@ import numpy
|
|||||||
import onnx
|
import onnx
|
||||||
|
|
||||||
|
|
||||||
def is_equal_tensor_proto(a, b):
|
def _is_equal_tensor_proto(a, b):
|
||||||
name_a = a.name
|
name_a = a.name
|
||||||
name_b = b.name
|
name_b = b.name
|
||||||
|
|
||||||
@@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def node_replace_input_with(node_proto, name, new_name):
|
def _node_replace_input_with(node_proto, name, new_name):
|
||||||
for i, input_name in enumerate(node_proto.input):
|
for i, input_name in enumerate(node_proto.input):
|
||||||
if input_name == name:
|
if input_name == name:
|
||||||
node_proto.input.insert(i, new_name)
|
node_proto.input.insert(i, new_name)
|
||||||
node_proto.input.pop(i + 1)
|
node_proto.input.pop(i + 1)
|
||||||
|
|
||||||
if node_proto.op_type == "If":
|
if node_proto.op_type == "If":
|
||||||
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||||
graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
|
_graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
|
||||||
if node_proto.op_type == "Loop":
|
if node_proto.op_type == "Loop":
|
||||||
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||||
|
|
||||||
|
|
||||||
def graph_replace_input_with(graph_proto, name, new_name):
|
def _graph_replace_input_with(graph_proto, name, new_name):
|
||||||
for n in graph_proto.node:
|
for n in graph_proto.node:
|
||||||
node_replace_input_with(n, name, new_name)
|
_node_replace_input_with(n, name, new_name)
|
||||||
|
|
||||||
|
|
||||||
def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
|
def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
|
||||||
inits_with_data = [i for i in model.graph.initializer]
|
inits_with_data = [i for i in model.graph.initializer]
|
||||||
inits = [i for i in model_without_ext.graph.initializer]
|
inits = [i for i in model_without_ext.graph.initializer]
|
||||||
for i, ref_i in ind_to_replace:
|
for i, ref_i in ind_to_replace:
|
||||||
@@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace)
|
|||||||
model_without_ext.graph.initializer.remove(inits[i])
|
model_without_ext.graph.initializer.remove(inits[i])
|
||||||
|
|
||||||
# for n in model.graph.node:
|
# for n in model.graph.node:
|
||||||
graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
|
_graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
|
||||||
|
|
||||||
|
|
||||||
def remove_dup_initializers(onnx_file_path):
|
def remove_dup_initializers(onnx_file_path):
|
||||||
|
"""
|
||||||
|
Removes duplicate initializers from the model to reduce its size.
|
||||||
|
Writes a new file in the same directory as onnx_file_path and returns the path to that file.
|
||||||
|
"""
|
||||||
|
|
||||||
model_file_folder = os.path.dirname(onnx_file_path)
|
model_file_folder = os.path.dirname(onnx_file_path)
|
||||||
model_file_name = os.path.basename(onnx_file_path)
|
model_file_name = os.path.basename(onnx_file_path)
|
||||||
|
|
||||||
@@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path):
|
|||||||
for j in range(i + 1, len(inits)):
|
for j in range(i + 1, len(inits)):
|
||||||
if j in dup_set:
|
if j in dup_set:
|
||||||
continue
|
continue
|
||||||
if is_equal_tensor_proto(inits[i], inits[j]):
|
if _is_equal_tensor_proto(inits[i], inits[j]):
|
||||||
dup_set.add(i)
|
dup_set.add(i)
|
||||||
dup_set.add(j)
|
dup_set.add(j)
|
||||||
|
|
||||||
@@ -103,8 +112,8 @@ def remove_dup_initializers(onnx_file_path):
|
|||||||
|
|
||||||
print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")
|
print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")
|
||||||
|
|
||||||
ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0])
|
ind_to_replace = sorted(ind_to_replace)
|
||||||
remove_dup_initializers_from_model(model, model, ind_to_replace)
|
_remove_dup_initializers_from_model(model, model, ind_to_replace)
|
||||||
|
|
||||||
optimized_model_file_name = "optimized_" + model_file_name
|
optimized_model_file_name = "optimized_" + model_file_name
|
||||||
new_model = os.path.join(model_file_folder, optimized_model_file_name)
|
new_model = os.path.join(model_file_folder, optimized_model_file_name)
|
||||||
1
examples/onnx/pytorch/summarization/requirements.txt
Normal file
1
examples/onnx/pytorch/summarization/requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
torch >= 1.10
|
||||||
@@ -20,7 +20,6 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -46,7 +45,7 @@ tokenizer_dict = {"facebook/bart-base": BartTokenizer}
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
parser = argparse.ArgumentParser(description="Export Bart model + Beam Search to ONNX graph.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||||
)
|
)
|
||||||
@@ -104,13 +103,12 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
ort_sess = None
|
ort_sess = None
|
||||||
onnx_bart = torch.jit.script(BARTBeamSearchGenerator(model))
|
bart_script_model = torch.jit.script(BARTBeamSearchGenerator(model))
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||||
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device)
|
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
# Test export here.
|
|
||||||
summary_ids = model.generate(
|
summary_ids = model.generate(
|
||||||
inputs["input_ids"],
|
inputs["input_ids"],
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
@@ -120,9 +118,8 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
|
|||||||
decoder_start_token_id=model.config.decoder_start_token_id,
|
decoder_start_token_id=model.config.decoder_start_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not ort_sess:
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
onnx_bart,
|
bart_script_model,
|
||||||
(
|
(
|
||||||
inputs["input_ids"],
|
inputs["input_ids"],
|
||||||
inputs["attention_mask"],
|
inputs["attention_mask"],
|
||||||
@@ -138,13 +135,15 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
|
|||||||
"input_ids": {0: "batch", 1: "seq"},
|
"input_ids": {0: "batch", 1: "seq"},
|
||||||
"output_ids": {0: "batch", 1: "seq_out"},
|
"output_ids": {0: "batch", 1: "seq_out"},
|
||||||
},
|
},
|
||||||
verbose=False,
|
|
||||||
strip_doc_string=False,
|
|
||||||
example_outputs=summary_ids,
|
example_outputs=summary_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("Model exported to {}".format(onnx_file_path))
|
||||||
|
|
||||||
new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
|
new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
|
||||||
|
|
||||||
|
logger.info("Deduplicated and optimized model written to {}".format(new_onnx_file_path))
|
||||||
|
|
||||||
ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
|
ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
|
||||||
ort_out = ort_sess.run(
|
ort_out = ort_sess.run(
|
||||||
None,
|
None,
|
||||||
@@ -159,14 +158,14 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
|
|||||||
|
|
||||||
np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
|
np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
print("========= Pass - Results are matched! =========")
|
logger.info("Model outputs from torch and ONNX Runtime are similar.")
|
||||||
|
logger.info("Success.")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
local_device = None
|
max_length = 5
|
||||||
local_max_length = 5
|
num_beams = 4
|
||||||
local_num_beams = 4
|
|
||||||
|
|
||||||
# Make one log on every process with the configuration for debugging.
|
# Make one log on every process with the configuration for debugging.
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -175,41 +174,31 @@ def main():
|
|||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.setLevel(logging.ERROR)
|
logger.setLevel(logging.INFO)
|
||||||
transformers.utils.logging.set_verbosity_error()
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
if args.model_name_or_path:
|
device = torch.device(args.device)
|
||||||
model, tokenizer = load_model_tokenizer(args.model_name_or_path, local_device)
|
|
||||||
else:
|
model, tokenizer = load_model_tokenizer(args.model_name_or_path, device)
|
||||||
raise ValueError("Make sure that model name has been passed")
|
|
||||||
|
|
||||||
if model.config.decoder_start_token_id is None:
|
if model.config.decoder_start_token_id is None:
|
||||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||||
|
|
||||||
if args.device:
|
model.to(device)
|
||||||
if args.device == "cuda" and not torch.cuda.is_available():
|
|
||||||
raise ValueError("CUDA is not available in this server.")
|
|
||||||
|
|
||||||
local_device = torch.device(args.device)
|
|
||||||
else:
|
|
||||||
local_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
||||||
|
|
||||||
model.to(local_device)
|
|
||||||
|
|
||||||
if args.max_length:
|
if args.max_length:
|
||||||
local_max_length = args.max_length
|
max_length = args.max_length
|
||||||
|
|
||||||
if args.num_beams:
|
if args.num_beams:
|
||||||
local_num_beams = args.num_beams
|
num_beams = args.num_beams
|
||||||
|
|
||||||
if args.output_file_path:
|
if args.output_file_path:
|
||||||
output_name = args.output_file_path
|
output_name = args.output_file_path
|
||||||
else:
|
else:
|
||||||
output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond)
|
output_name = "BART.onnx"
|
||||||
|
|
||||||
export_and_validate_model(model, tokenizer, output_name, local_num_beams, local_max_length)
|
logger.info("Exporting model to ONNX")
|
||||||
|
export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)
|
||||||
logger.info("***** Running export *****")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -1 +0,0 @@
|
|||||||
torch >= 1.8
|
|
||||||
Reference in New Issue
Block a user