[T5] make decoder input ids optional for t5 training (#3521)

* make decoder input ids optional for t5 training

* lm_lables should not be shifted in t5

* add tests

* finish shift right functionality for PT T5

* move shift right to correct class

* cleaner code

* replace -100 values with pad token id

* add assert statement

* remove unnecessary for loop

* make style
This commit is contained in:
Patrick von Platen
2020-03-30 13:45:26 +02:00
committed by GitHub
parent 5b44e0a31b
commit 75ec6c9e3a
6 changed files with 79 additions and 18 deletions

View File

@@ -52,7 +52,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_ids=[1],
eos_token_id=1,
pad_token_id=0,
scope=None,
):
@@ -71,7 +71,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_ids = eos_token_ids
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.scope = scope
@@ -97,7 +97,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_ids=self.eos_token_ids,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
)