[T5] make decoder input ids optional for t5 training (#3521)
* make decoder input ids optional for t5 training * lm_lables should not be shifted in t5 * add tests * finish shift right functionality for PT T5 * move shift right to correct class * cleaner code * replace -100 values with pad token id * add assert statement * remove unnecessary for loop * make style
This commit is contained in:
committed by
GitHub
parent
5b44e0a31b
commit
75ec6c9e3a
@@ -468,7 +468,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
length_penalty=1.0,
|
||||
no_repeat_ngram_size=3,
|
||||
early_stopping=True,
|
||||
decoder_start_token_id=model.config.eos_token_ids[0],
|
||||
decoder_start_token_id=model.config.eos_token_id,
|
||||
)
|
||||
|
||||
decoded = [
|
||||
|
||||
@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from transformers import T5Config, T5Model, T5ForConditionalGeneration
|
||||
from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
@@ -57,8 +58,9 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
relative_attention_num_buckets=8,
|
||||
dropout_rate=0.1,
|
||||
initializer_factor=0.002,
|
||||
eos_token_ids=[1],
|
||||
eos_token_id=1,
|
||||
pad_token_id=0,
|
||||
decoder_start_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -78,8 +80,9 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.dropout_rate = dropout_rate
|
||||
self.initializer_factor = initializer_factor
|
||||
self.scope = scope
|
||||
self.eos_token_ids = eos_token_ids
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
|
||||
@@ -106,9 +109,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
||||
dropout_rate=self.dropout_rate,
|
||||
initializer_factor=self.initializer_factor,
|
||||
eos_token_ids=self.eos_token_ids,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -123,6 +127,39 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def check_prepare_lm_labels_via_shift_left(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# make sure that lm_labels are correctly padded from the right
|
||||
lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)
|
||||
|
||||
# add casaul pad token mask
|
||||
triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
|
||||
lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
|
||||
decoder_input_ids = model._shift_right(lm_labels)
|
||||
|
||||
for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
|
||||
# first item
|
||||
self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
|
||||
if i < decoder_input_ids_slice.shape[-1]:
|
||||
if i < decoder_input_ids.shape[-1] - 1:
|
||||
# items before diagonal
|
||||
self.parent.assertListEqual(
|
||||
decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
|
||||
)
|
||||
# pad items after diagonal
|
||||
if i < decoder_input_ids.shape[-1] - 2:
|
||||
self.parent.assertListEqual(
|
||||
decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
|
||||
)
|
||||
else:
|
||||
# all items after square
|
||||
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
|
||||
|
||||
def create_and_check_t5_model(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
@@ -197,6 +234,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_shift_right(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
|
||||
|
||||
def test_t5_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_model(*config_and_inputs)
|
||||
|
||||
@@ -52,7 +52,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
relative_attention_num_buckets=8,
|
||||
dropout_rate=0.1,
|
||||
initializer_factor=0.002,
|
||||
eos_token_ids=[1],
|
||||
eos_token_id=1,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
@@ -71,7 +71,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.dropout_rate = dropout_rate
|
||||
self.initializer_factor = initializer_factor
|
||||
self.eos_token_ids = eos_token_ids
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
|
||||
@@ -97,7 +97,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
||||
dropout_rate=self.dropout_rate,
|
||||
initializer_factor=self.initializer_factor,
|
||||
eos_token_ids=self.eos_token_ids,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user