From 242ec31aa59b358e631d981b545fd08330584ea8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Jun 2021 16:31:32 +0100 Subject: [PATCH] [Flax] Refactor MLM (#12013) * fix_torch_device_generate_test * remove @ * finish refactor Co-authored-by: Patrick von Platen --- .../flax/language-modeling/run_mlm_flax.py | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 6be1f7ed18..dddd6ce478 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -34,6 +34,7 @@ import numpy as np from datasets import load_dataset from tqdm import tqdm +import flax import jax import jax.numpy as jnp import optax @@ -185,9 +186,7 @@ class DataTrainingArguments: assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." -# Adapted from transformers/data/data_collator.py -# Letting here for now, let's discuss where it should live -@dataclass +@flax.struct.dataclass class FlaxDataCollatorForLanguageModeling: """ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they @@ -196,12 +195,8 @@ class FlaxDataCollatorForLanguageModeling: Args: tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): The tokenizer used for encoding the data. - mlm (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the - inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for - non-masked tokens and the value to predict for the masked token. mlm_probability (:obj:`float`, `optional`, defaults to 0.15): - The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`. + The probability with which to (randomly) mask tokens in the input. .. note:: @@ -212,11 +207,10 @@ class FlaxDataCollatorForLanguageModeling: """ tokenizer: PreTrainedTokenizerBase - mlm: bool = True mlm_probability: float = 0.15 def __post_init__(self): - if self.mlm and self.tokenizer.mask_token is None: + if self.tokenizer.mask_token is None: raise ValueError( "This tokenizer does not have a mask token which is necessary for masked language modeling. " "You should pass `mlm=False` to train on causal language modeling instead." @@ -228,15 +222,10 @@ class FlaxDataCollatorForLanguageModeling: # If special token mask has been preprocessed, pop it from the dict. special_tokens_mask = batch.pop("special_tokens_mask", None) - if self.mlm: - batch["input_ids"], batch["labels"] = self.mask_tokens( - batch["input_ids"], special_tokens_mask=special_tokens_mask - ) - else: - labels = batch["input_ids"].copy() - if self.tokenizer.pad_token_id is not None: - labels[labels == self.tokenizer.pad_token_id] = -100 - batch["labels"] = labels + + batch["input_ids"], batch["labels"] = self.mask_tokens( + batch["input_ids"], special_tokens_mask=special_tokens_mask + ) return batch def mask_tokens(