[T5] make decoder input ids optional for t5 training (#3521)
* make decoder input ids optional for t5 training * lm_lables should not be shifted in t5 * add tests * finish shift right functionality for PT T5 * move shift right to correct class * cleaner code * replace -100 values with pad token id * add assert statement * remove unnecessary for loop * make style
This commit is contained in:
committed by
GitHub
parent
5b44e0a31b
commit
75ec6c9e3a
@@ -52,7 +52,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
relative_attention_num_buckets=8,
|
||||
dropout_rate=0.1,
|
||||
initializer_factor=0.002,
|
||||
eos_token_ids=[1],
|
||||
eos_token_id=1,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
@@ -71,7 +71,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.relative_attention_num_buckets = relative_attention_num_buckets
|
||||
self.dropout_rate = dropout_rate
|
||||
self.initializer_factor = initializer_factor
|
||||
self.eos_token_ids = eos_token_ids
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
|
||||
@@ -97,7 +97,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
relative_attention_num_buckets=self.relative_attention_num_buckets,
|
||||
dropout_rate=self.dropout_rate,
|
||||
initializer_factor=self.initializer_factor,
|
||||
eos_token_ids=self.eos_token_ids,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user