[Flax] Refactor MLM (#12013)
* fix_torch_device_generate_test
* remove @
* finish refactor

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
4674061b2a
commit
242ec31aa5
@@ -34,6 +34,7 @@ import numpy as np
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import flax
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
@@ -185,9 +186,7 @@ class DataTrainingArguments:
|
|||||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||||
|
|
||||||
|
|
||||||
# Adapted from transformers/data/data_collator.py
|
@flax.struct.dataclass
|
||||||
# Letting here for now, let's discuss where it should live
|
|
||||||
@dataclass
|
|
||||||
class FlaxDataCollatorForLanguageModeling:
|
class FlaxDataCollatorForLanguageModeling:
|
||||||
"""
|
"""
|
||||||
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
||||||
@@ -196,12 +195,8 @@ class FlaxDataCollatorForLanguageModeling:
|
|||||||
Args:
|
Args:
|
||||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||||
The tokenizer used for encoding the data.
|
The tokenizer used for encoding the data.
|
||||||
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
|
||||||
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
|
|
||||||
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
|
|
||||||
non-masked tokens and the value to predict for the masked token.
|
|
||||||
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
||||||
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
|
The probability with which to (randomly) mask tokens in the input.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
@@ -212,11 +207,10 @@ class FlaxDataCollatorForLanguageModeling:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
mlm: bool = True
|
|
||||||
mlm_probability: float = 0.15
|
mlm_probability: float = 0.15
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.mlm and self.tokenizer.mask_token is None:
|
if self.tokenizer.mask_token is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||||
@@ -228,15 +222,10 @@ class FlaxDataCollatorForLanguageModeling:
|
|||||||
|
|
||||||
# If special token mask has been preprocessed, pop it from the dict.
|
# If special token mask has been preprocessed, pop it from the dict.
|
||||||
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
||||||
if self.mlm:
|
|
||||||
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
||||||
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
labels = batch["input_ids"].copy()
|
|
||||||
if self.tokenizer.pad_token_id is not None:
|
|
||||||
labels[labels == self.tokenizer.pad_token_id] = -100
|
|
||||||
batch["labels"] = labels
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def mask_tokens(
|
def mask_tokens(
|
||||||
|
|||||||
Reference in New Issue
Block a user