[FlaxRoberta] Add FlaxRobertaModels & adapt run_mlm_flax.py (#11470)
* add flax roberta * make style * correct initialiazation * modify model to save weights * fix copied from * fix copied from * correct some more code * add more roberta models * Apply suggestions from code review * merge from master * finish * finish docs Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
2ce0fb84cc
commit
084a187da3
@@ -45,7 +45,7 @@ from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
TensorType,
|
||||
@@ -105,6 +105,12 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -162,6 +168,10 @@ class DataTrainingArguments:
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
},
|
||||
)
|
||||
line_by_line: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
@@ -537,27 +547,76 @@ if __name__ == "__main__":
|
||||
column_names = datasets["validation"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Remove empty lines
|
||||
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
||||
return tokenizer(
|
||||
examples,
|
||||
return_special_tokens_mask=True,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
max_length=data_args.max_seq_length,
|
||||
if data_args.line_by_line:
|
||||
# When using line_by_line, we just tokenize each nonempty line.
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Remove empty lines
|
||||
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
||||
return tokenizer(
|
||||
examples,
|
||||
return_special_tokens_mask=True,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
max_length=max_seq_length,
|
||||
)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
input_columns=[text_column_name],
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
input_columns=[text_column_name],
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
else:
|
||||
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
||||
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
||||
# efficient when it receives the `special_tokens_mask`.
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||
# max_seq_length.
|
||||
def group_texts(examples):
|
||||
# Concatenate all texts.
|
||||
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
||||
# customize this part to your needs.
|
||||
total_length = (total_length // max_seq_length) * max_seq_length
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
return result
|
||||
|
||||
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
||||
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
||||
# might be slower to preprocess.
|
||||
#
|
||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||
|
||||
tokenized_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Enable tensorboard only on the master node
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
@@ -571,13 +630,7 @@ if __name__ == "__main__":
|
||||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
model = FlaxBertForMaskedLM.from_pretrained(
|
||||
"bert-base-cased",
|
||||
dtype=jnp.float32,
|
||||
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
|
||||
seed=training_args.seed,
|
||||
dropout_rate=0.1,
|
||||
)
|
||||
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = Adam(
|
||||
@@ -602,8 +655,8 @@ if __name__ == "__main__":
|
||||
|
||||
# Store some constant
|
||||
nb_epochs = int(training_args.num_train_epochs)
|
||||
batch_size = int(training_args.train_batch_size)
|
||||
eval_batch_size = int(training_args.eval_batch_size)
|
||||
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
|
||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
@@ -657,3 +710,8 @@ if __name__ == "__main__":
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
for name, value in eval_summary.items():
|
||||
summary_writer.scalar(name, value, epoch)
|
||||
|
||||
# save last checkpoint
|
||||
if jax.host_id() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
|
||||
Reference in New Issue
Block a user