[Flax] Refactor MLM (#12013)

* fix_torch_device_generate_test

* remove @

* finish refactor

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-06-03 16:31:32 +01:00
committed by GitHub
parent 4674061b2a
commit 242ec31aa5

View File

@@ -34,6 +34,7 @@ import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import optax
@@ -185,9 +186,7 @@ class DataTrainingArguments:
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
# Adapted from transformers/data/data_collator.py
# Letting here for now, let's discuss where it should live
@dataclass
@flax.struct.dataclass
class FlaxDataCollatorForLanguageModeling:
"""
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
@@ -196,12 +195,8 @@ class FlaxDataCollatorForLanguageModeling:
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
non-masked tokens and the value to predict for the masked token.
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
The probability with which to (randomly) mask tokens in the input.
.. note::
@@ -212,11 +207,10 @@ class FlaxDataCollatorForLanguageModeling:
"""
tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
@@ -228,15 +222,10 @@ class FlaxDataCollatorForLanguageModeling:
# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
if self.mlm:
batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
else:
labels = batch["input_ids"].copy()
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
batch["labels"] = labels
batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
return batch
def mask_tokens(