[Flax] Refactor MLM (#12013)
* fix_torch_device_generate_test * remove @ * finish refactor Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
4674061b2a
commit
242ec31aa5
@@ -34,6 +34,7 @@ import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
@@ -185,9 +186,7 @@ class DataTrainingArguments:
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||
|
||||
|
||||
# Adapted from transformers/data/data_collator.py
|
||||
# Letting here for now, let's discuss where it should live
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxDataCollatorForLanguageModeling:
|
||||
"""
|
||||
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
||||
@@ -196,12 +195,8 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
|
||||
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
|
||||
non-masked tokens and the value to predict for the masked token.
|
||||
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
||||
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
|
||||
The probability with which to (randomly) mask tokens in the input.
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -212,11 +207,10 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
mlm: bool = True
|
||||
mlm_probability: float = 0.15
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mlm and self.tokenizer.mask_token is None:
|
||||
if self.tokenizer.mask_token is None:
|
||||
raise ValueError(
|
||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||
@@ -228,15 +222,10 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
|
||||
# If special token mask has been preprocessed, pop it from the dict.
|
||||
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
||||
if self.mlm:
|
||||
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
||||
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
||||
)
|
||||
else:
|
||||
labels = batch["input_ids"].copy()
|
||||
if self.tokenizer.pad_token_id is not None:
|
||||
labels[labels == self.tokenizer.pad_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
||||
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
||||
)
|
||||
return batch
|
||||
|
||||
def mask_tokens(
|
||||
|
||||
Reference in New Issue
Block a user