Make T5 compatible with ONNX (#5518)
* Default decoder inputs to encoder ones for T5 if neither is specified. * Fixing typo, now all tests are passing. * Changing einsum to operations supported by ONNX. * Adding a test to ensure T5 can be exported to ONNX with opset >= 9. * Modified test for ONNX export to make it faster. * Styling changes. * Styling changes. * Changing notation for matrix multiplication. Co-authored-by: Abel Riboulot <tkai@protomail.com>
This commit is contained in:
@@ -358,7 +358,10 @@ class T5Attention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
present_key_value_state = (None,)
|
present_key_value_state = (None,)
|
||||||
|
|
||||||
scores = torch.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen)
|
# (bs, n_heads, qlen, klen)
|
||||||
|
scores = torch.matmul(
|
||||||
|
q, k.transpose(3, 2)
|
||||||
|
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", q, k), compatible with onnx op>9
|
||||||
|
|
||||||
if position_bias is None:
|
if position_bias is None:
|
||||||
if not self.has_relative_attention_bias:
|
if not self.has_relative_attention_bias:
|
||||||
@@ -818,7 +821,8 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
|
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
|
||||||
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
||||||
`T5 Training <./t5.html#training>`__.
|
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||||
|
decoder_input_ids takes the value of input_ids.
|
||||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
@@ -837,7 +841,8 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
|
If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
|
||||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||||
|
decoder_inputs_embeds takes the value of inputs_embeds.
|
||||||
head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
||||||
Mask to nullify selected heads of the self-attention modules.
|
Mask to nullify selected heads of the self-attention modules.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
@@ -934,7 +939,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
>>> model = T5Model.from_pretrained('t5-small')
|
>>> model = T5Model.from_pretrained('t5-small')
|
||||||
|
|
||||||
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
|
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
|
||||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
|
>>> outputs = model(input_ids=input_ids)
|
||||||
|
|
||||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||||
"""
|
"""
|
||||||
@@ -953,6 +958,12 @@ class T5Model(T5PreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = encoder_outputs[0]
|
hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
# If the model is only provided with either input_ids or inputs_embeds,
|
||||||
|
# use them as the inputs of the decoder. self.encoder checks for input_ids XOR inputs_embeds
|
||||||
|
if (decoder_input_ids is None) and (decoder_inputs_embeds is None):
|
||||||
|
decoder_input_ids = input_ids
|
||||||
|
decoder_inputs_embeds = inputs_embeds
|
||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_value_states is not None:
|
if decoder_past_key_value_states is not None:
|
||||||
@@ -1076,7 +1087,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||||
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
|
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
|
||||||
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
|
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
|
||||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
|
>>> outputs = model(input_ids=input_ids, labels=input_ids)
|
||||||
>>> loss, prediction_scores = outputs[:2]
|
>>> loss, prediction_scores = outputs[:2]
|
||||||
|
|
||||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||||
|
|||||||
@@ -351,6 +351,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = T5Model.from_pretrained(model_name)
|
model = T5Model.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_export_to_onnx(self):
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
model = T5Model(config_and_inputs[0])
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
torch.onnx.export(
|
||||||
|
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class T5ModelIntegrationTests(unittest.TestCase):
|
class T5ModelIntegrationTests(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user