[FlaxT5 Example] fix flax t5 example pretraining (#15835)
This commit is contained in:
committed by
GitHub
parent
01485ceec3
commit
10b76987fc
@@ -368,7 +368,9 @@ class FlaxDataCollatorForT5MLM:
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
|
||||
input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
|
||||
# input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
|
||||
# masked tokens coming after sentinel tokens and should be removed
|
||||
input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
|
||||
input_ids = np.concatenate(
|
||||
[input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user