Update the example of exporting Bart + BeamSearch to ONNX module to resolve comments. (#14310)
* Update code to resolve comments left in previous PR. * Add README.md file for this example. * Update examples/onnx/pytorch/translation/README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update examples/onnx/pytorch/translation/README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update examples/onnx/pytorch/translation/README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update README.md file to resolve comments. * Add a section name. * Update examples/onnx/pytorch/translation/README.md Co-authored-by: Gary Miguel <garymm@garymm.org> * Add more comments for _convert_past_list_to_tuple(). * Change the default file name to a consistent one. * Fix a format issue. * Update examples/onnx/pytorch/translation/README.md Co-authored-by: Gary Miguel <garymm@garymm.org> * Update examples/onnx/pytorch/translation/run_onnx_exporter.py Co-authored-by: Gary Miguel <garymm@garymm.org> * Update examples/onnx/pytorch/translation/README.md Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Change the folder to summarization and address some other coments. * Update the torch version. Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Gary Miguel <garymm@garymm.org> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
43
examples/onnx/pytorch/summarization/README.md
Normal file
43
examples/onnx/pytorch/summarization/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
<!---
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Bart + Beam Search to ONNX
|
||||
|
||||
|
||||
|
||||
This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.
|
||||
|
||||
Beam Search contains a for-loop workflow, so we need to make them TorchScript-compatible for exporting to ONNX. This example shows how to make a Bart model be TorchScript-compatible by wrapping up it into a new model. In addition, some changes were made to the `beam_search()` function to make it TorchScript-compatible.
|
||||
|
||||
|
||||
## How to run the example
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
pip install .
|
||||
```
|
||||
Then cd in this example folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Now you can run the example command below to get the example ONNX file:
|
||||
|
||||
```bash
|
||||
python run_onnx_exporter.py --model_name_or_path facebook/bart-base
|
||||
```
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import itertools
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -8,23 +9,23 @@ from transformers import BartConfig
|
||||
from transformers.generation_utils import GenerationMixin
|
||||
|
||||
|
||||
def flatten_list(past):
|
||||
values = []
|
||||
if past is not None:
|
||||
for i, p in enumerate(past):
|
||||
for j, q in enumerate(p):
|
||||
values.append(q)
|
||||
def _convert_past_list_to_tuple(past_key_values):
|
||||
"""
|
||||
In Bart model, the type of past_key_values is tuple(tuple(torch.FloatTensor)) which is not
|
||||
TorchScript-compatible. To support this, we have to convert it during the export process.
|
||||
This function will convert past values from a list to tuple(tuple(torch.FloatTensor)) for
|
||||
the inner decoder.
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def list_to_tuple(past):
|
||||
According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
|
||||
so we convert every 4 elements in the list as a tuple(torch.FloatTensor).
|
||||
"""
|
||||
count_of_each_inner_tuple = 4
|
||||
results = ()
|
||||
temp_result = ()
|
||||
count_n = len(past) // 4
|
||||
count_n = len(past_key_values) // count_of_each_inner_tuple
|
||||
for idx in range(count_n):
|
||||
real_idx = idx * 4
|
||||
temp_result = tuple(past[real_idx : real_idx + 4])
|
||||
real_idx = idx * count_of_each_inner_tuple
|
||||
temp_result = tuple(past_key_values[real_idx : real_idx + count_of_each_inner_tuple])
|
||||
results += ((temp_result),)
|
||||
|
||||
return results
|
||||
@@ -51,7 +52,7 @@ class DecoderForONNX(torch.nn.Module):
|
||||
def forward(self, input_ids, encoder_state, attention_mask, past=None):
|
||||
all_results = None
|
||||
if past is not None:
|
||||
all_results = list_to_tuple(past)
|
||||
all_results = _convert_past_list_to_tuple(past)
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
last_hidden_state, past_key_values = self.decoder(
|
||||
@@ -68,28 +69,33 @@ class DecoderForONNX(torch.nn.Module):
|
||||
return last_hidden_state, past_values
|
||||
|
||||
|
||||
def create_traced_encoder(encoder, input_ids, attention_mask):
|
||||
def _create_traced_encoder(encoder, input_ids, attention_mask):
|
||||
encoder_c = copy.deepcopy(encoder)
|
||||
encoder_for_onnx = EncoderForONNX(encoder_c)
|
||||
|
||||
# return torch.jit.trace(encoder, (input_ids, attention_mask))
|
||||
return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))
|
||||
|
||||
|
||||
def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
|
||||
def _create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
|
||||
decoder_c = copy.deepcopy(decoder)
|
||||
decoder_for_onnx = DecoderForONNX(decoder_c)
|
||||
past_values = flatten_list(past)
|
||||
past_values = list(itertools.chain.from_iterable(past or ()))
|
||||
|
||||
# Do this twice so we got 2 different decoders for further work.
|
||||
if past_values is None or len(past_values) == 0:
|
||||
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
|
||||
else:
|
||||
if past_values:
|
||||
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
|
||||
else:
|
||||
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
|
||||
|
||||
|
||||
class BartConfigTS(BartConfig, torch.nn.Module):
|
||||
def init_module(self):
|
||||
"""
|
||||
BartConfigTS is a TorchScript-compatible transformers.models.bart.configuration_bart.BartConfig.
|
||||
TorchScript only supports sub-classes of torch.nn.Module.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
BartConfig.__init__(self, config)
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
|
||||
@@ -127,7 +133,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.config = BartConfigTS(model.config)
|
||||
self.config.init_module()
|
||||
self.config.force_bos_token_to_be_generated = False
|
||||
self._trace_modules(model)
|
||||
self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
|
||||
@@ -136,7 +141,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
||||
self.decoder_layers = model.config.decoder_layers
|
||||
|
||||
def _trace_modules(self, model):
|
||||
# Be aware of the last one 2 should be kept.
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
@@ -200,89 +204,25 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
|
||||
57,
|
||||
8629,
|
||||
5,
|
||||
2,
|
||||
model.config.eos_token_id,
|
||||
]
|
||||
],
|
||||
device=model.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
attention_mask = torch.tensor(
|
||||
[
|
||||
[
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
]
|
||||
],
|
||||
[[True] * input_ids.shape[-1]],
|
||||
device=model.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
|
||||
self.encoder = _create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
|
||||
encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||
decoder = model.model.decoder
|
||||
decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
|
||||
self.decoder_no_past = create_traced_decoder(
|
||||
self.decoder_no_past = _create_traced_decoder(
|
||||
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
|
||||
)
|
||||
self.decoder_with_past = create_traced_decoder(
|
||||
self.decoder_with_past = _create_traced_decoder(
|
||||
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
|
||||
)
|
||||
|
||||
@@ -414,8 +354,8 @@ class BeamSearchScorerTS(torch.nn.Module):
|
||||
self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
|
||||
self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
|
||||
self._beam_hyps_max_length: int = self.max_length - 1
|
||||
self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
|
||||
self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
|
||||
self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatibility
|
||||
self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatibility
|
||||
|
||||
def is_done(self) -> torch.Tensor:
|
||||
return self._done.all()
|
||||
@@ -474,11 +414,11 @@ class BeamSearchScorerTS(torch.nn.Module):
|
||||
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||
hyps_count = self.hypo_len(hypo_idx)
|
||||
if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
|
||||
# NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
|
||||
# NOTE: work around difference of torch.sum(empty_tensor) == 0, while error in onnx.
|
||||
# Bug: https://msdata.visualstudio.com/Vienna/_workitems/edit/1486599
|
||||
beam_idx = (
|
||||
torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
|
||||
)
|
||||
# beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
|
||||
self._beam_scores.insert(beam_idx, torch.tensor([score]))
|
||||
self._beam_hyps.insert(beam_idx, hyp)
|
||||
if hyps_count + 1 > self.num_beams:
|
||||
@@ -605,7 +545,7 @@ class BeamSearchScorerTS(torch.nn.Module):
|
||||
self.hypo_add(final_tokens, final_score, batch_idx)
|
||||
|
||||
# select the best hypotheses
|
||||
# NOTE: new is not scriptable
|
||||
# NOTE: torch.Tensor.new_zeros() is not scriptable
|
||||
sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
|
||||
best = []
|
||||
best_scores = torch.zeros(
|
||||
@@ -782,7 +722,6 @@ class BARTBeamSearchGenerator(BARTGenerator):
|
||||
bos_token_id=bos_token_id,
|
||||
)
|
||||
|
||||
# from generation_utils.py
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
length_penalty = self.config.length_penalty
|
||||
@@ -1,3 +1,7 @@
|
||||
"""
|
||||
Code to remove duplicate initializers to reduce ONNX model size.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy
|
||||
@@ -5,7 +9,7 @@ import numpy
|
||||
import onnx
|
||||
|
||||
|
||||
def is_equal_tensor_proto(a, b):
|
||||
def _is_equal_tensor_proto(a, b):
|
||||
name_a = a.name
|
||||
name_b = b.name
|
||||
|
||||
@@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b):
|
||||
return res
|
||||
|
||||
|
||||
def node_replace_input_with(node_proto, name, new_name):
|
||||
def _node_replace_input_with(node_proto, name, new_name):
|
||||
for i, input_name in enumerate(node_proto.input):
|
||||
if input_name == name:
|
||||
node_proto.input.insert(i, new_name)
|
||||
node_proto.input.pop(i + 1)
|
||||
|
||||
if node_proto.op_type == "If":
|
||||
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||
graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
|
||||
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||
_graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
|
||||
if node_proto.op_type == "Loop":
|
||||
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
|
||||
|
||||
|
||||
def graph_replace_input_with(graph_proto, name, new_name):
|
||||
def _graph_replace_input_with(graph_proto, name, new_name):
|
||||
for n in graph_proto.node:
|
||||
node_replace_input_with(n, name, new_name)
|
||||
_node_replace_input_with(n, name, new_name)
|
||||
|
||||
|
||||
def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
|
||||
def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
|
||||
inits_with_data = [i for i in model.graph.initializer]
|
||||
inits = [i for i in model_without_ext.graph.initializer]
|
||||
for i, ref_i in ind_to_replace:
|
||||
@@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace)
|
||||
model_without_ext.graph.initializer.remove(inits[i])
|
||||
|
||||
# for n in model.graph.node:
|
||||
graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
|
||||
_graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
|
||||
|
||||
|
||||
def remove_dup_initializers(onnx_file_path):
|
||||
"""
|
||||
Removes duplicate initializers from the model to reduce its size.
|
||||
Writes a new file in the same directory as onnx_file_path and returns the path to that file.
|
||||
"""
|
||||
|
||||
model_file_folder = os.path.dirname(onnx_file_path)
|
||||
model_file_name = os.path.basename(onnx_file_path)
|
||||
|
||||
@@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path):
|
||||
for j in range(i + 1, len(inits)):
|
||||
if j in dup_set:
|
||||
continue
|
||||
if is_equal_tensor_proto(inits[i], inits[j]):
|
||||
if _is_equal_tensor_proto(inits[i], inits[j]):
|
||||
dup_set.add(i)
|
||||
dup_set.add(j)
|
||||
|
||||
@@ -103,8 +112,8 @@ def remove_dup_initializers(onnx_file_path):
|
||||
|
||||
print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")
|
||||
|
||||
ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0])
|
||||
remove_dup_initializers_from_model(model, model, ind_to_replace)
|
||||
ind_to_replace = sorted(ind_to_replace)
|
||||
_remove_dup_initializers_from_model(model, model, ind_to_replace)
|
||||
|
||||
optimized_model_file_name = "optimized_" + model_file_name
|
||||
new_model = os.path.join(model_file_folder, optimized_model_file_name)
|
||||
1
examples/onnx/pytorch/summarization/requirements.txt
Normal file
1
examples/onnx/pytorch/summarization/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
torch >= 1.10
|
||||
@@ -20,7 +20,6 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -46,7 +45,7 @@ tokenizer_dict = {"facebook/bart-base": BartTokenizer}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
||||
parser = argparse.ArgumentParser(description="Export Bart model + Beam Search to ONNX graph.")
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
@@ -104,13 +103,12 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
|
||||
model.eval()
|
||||
|
||||
ort_sess = None
|
||||
onnx_bart = torch.jit.script(BARTBeamSearchGenerator(model))
|
||||
bart_script_model = torch.jit.script(BARTBeamSearchGenerator(model))
|
||||
|
||||
with torch.no_grad():
|
||||
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device)
|
||||
|
||||
# Test export here.
|
||||
summary_ids = model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
@@ -120,53 +118,54 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
|
||||
decoder_start_token_id=model.config.decoder_start_token_id,
|
||||
)
|
||||
|
||||
if not ort_sess:
|
||||
torch.onnx.export(
|
||||
onnx_bart,
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
num_beams,
|
||||
max_length,
|
||||
model.config.decoder_start_token_id,
|
||||
),
|
||||
onnx_file_path,
|
||||
opset_version=14,
|
||||
input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
|
||||
output_names=["output_ids"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch", 1: "seq"},
|
||||
"output_ids": {0: "batch", 1: "seq_out"},
|
||||
},
|
||||
verbose=False,
|
||||
strip_doc_string=False,
|
||||
example_outputs=summary_ids,
|
||||
)
|
||||
torch.onnx.export(
|
||||
bart_script_model,
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
num_beams,
|
||||
max_length,
|
||||
model.config.decoder_start_token_id,
|
||||
),
|
||||
onnx_file_path,
|
||||
opset_version=14,
|
||||
input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
|
||||
output_names=["output_ids"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch", 1: "seq"},
|
||||
"output_ids": {0: "batch", 1: "seq_out"},
|
||||
},
|
||||
example_outputs=summary_ids,
|
||||
)
|
||||
|
||||
new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
|
||||
logger.info("Model exported to {}".format(onnx_file_path))
|
||||
|
||||
ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
|
||||
ort_out = ort_sess.run(
|
||||
None,
|
||||
{
|
||||
"input_ids": inputs["input_ids"].cpu().numpy(),
|
||||
"attention_mask": inputs["attention_mask"].cpu().numpy(),
|
||||
"num_beams": np.array(num_beams),
|
||||
"max_length": np.array(max_length),
|
||||
"decoder_start_token_id": np.array(model.config.decoder_start_token_id),
|
||||
},
|
||||
)
|
||||
new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
|
||||
|
||||
np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
|
||||
logger.info("Deduplicated and optimized model written to {}".format(new_onnx_file_path))
|
||||
|
||||
print("========= Pass - Results are matched! =========")
|
||||
ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
|
||||
ort_out = ort_sess.run(
|
||||
None,
|
||||
{
|
||||
"input_ids": inputs["input_ids"].cpu().numpy(),
|
||||
"attention_mask": inputs["attention_mask"].cpu().numpy(),
|
||||
"num_beams": np.array(num_beams),
|
||||
"max_length": np.array(max_length),
|
||||
"decoder_start_token_id": np.array(model.config.decoder_start_token_id),
|
||||
},
|
||||
)
|
||||
|
||||
np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
|
||||
|
||||
logger.info("Model outputs from torch and ONNX Runtime are similar.")
|
||||
logger.info("Success.")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
local_device = None
|
||||
local_max_length = 5
|
||||
local_num_beams = 4
|
||||
max_length = 5
|
||||
num_beams = 4
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
@@ -175,41 +174,31 @@ def main():
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
logger.setLevel(logging.ERROR)
|
||||
logger.setLevel(logging.INFO)
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
if args.model_name_or_path:
|
||||
model, tokenizer = load_model_tokenizer(args.model_name_or_path, local_device)
|
||||
else:
|
||||
raise ValueError("Make sure that model name has been passed")
|
||||
device = torch.device(args.device)
|
||||
|
||||
model, tokenizer = load_model_tokenizer(args.model_name_or_path, device)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
if args.device:
|
||||
if args.device == "cuda" and not torch.cuda.is_available():
|
||||
raise ValueError("CUDA is not available in this server.")
|
||||
|
||||
local_device = torch.device(args.device)
|
||||
else:
|
||||
local_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
model.to(local_device)
|
||||
model.to(device)
|
||||
|
||||
if args.max_length:
|
||||
local_max_length = args.max_length
|
||||
max_length = args.max_length
|
||||
|
||||
if args.num_beams:
|
||||
local_num_beams = args.num_beams
|
||||
num_beams = args.num_beams
|
||||
|
||||
if args.output_file_path:
|
||||
output_name = args.output_file_path
|
||||
else:
|
||||
output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond)
|
||||
output_name = "BART.onnx"
|
||||
|
||||
export_and_validate_model(model, tokenizer, output_name, local_num_beams, local_max_length)
|
||||
|
||||
logger.info("***** Running export *****")
|
||||
logger.info("Exporting model to ONNX")
|
||||
export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1 +0,0 @@
|
||||
torch >= 1.8
|
||||
Reference in New Issue
Block a user