[FlaxT5 Example] fix flax t5 example pretraining (#15835)
This commit is contained in:
committed by
GitHub
parent
01485ceec3
commit
10b76987fc
@@ -368,7 +368,9 @@ class FlaxDataCollatorForT5MLM:
|
|||||||
batch_size = input_ids.shape[0]
|
batch_size = input_ids.shape[0]
|
||||||
|
|
||||||
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
|
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
|
||||||
input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
|
# input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
|
||||||
|
# masked tokens coming after sentinel tokens and should be removed
|
||||||
|
input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
|
||||||
input_ids = np.concatenate(
|
input_ids = np.concatenate(
|
||||||
[input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
|
[input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user