Add Flax example tests (#14599)

* add test for glue

* add tests for clm

* fix clm test

* add summrization tests

* more tests

* fix few tests

* add test for t5 mlm

* fix t5 mlm test

* fix tests for multi device

* cleanup

* ci job

* fix metric file name

* make t5 more robust
This commit is contained in:
Suraj Patil
2021-12-06 10:48:58 +05:30
committed by GitHub
parent 803a8cd18f
commit c5bd732ac6
11 changed files with 553 additions and 6 deletions

View File

@@ -0,0 +1,7 @@
datasets >= 1.1.3
pytest
conllu
nltk
rouge-score
seqeval
tensorboard

View File

@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=causal-lm
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import json
import logging
import math
import os
@@ -672,6 +673,32 @@ def main():
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
# Eval after training
if training_args.do_eval:
eval_metrics = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
eval_steps = len(eval_dataset) // eval_batch_size
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = shard(next(eval_loader))
metrics = p_eval_step(state.params, batch)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
except OverflowError:
eval_metrics["perplexity"] = float("inf")
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@@ -20,7 +20,9 @@ text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
import json
import logging
import math
import os
import sys
import time
@@ -271,7 +273,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)
if __name__ == "__main__":
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
@@ -700,3 +702,41 @@ if __name__ == "__main__":
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
try:
perplexity = math.exp(eval_metrics["loss"])
except OverflowError:
perplexity = float("inf")
eval_metrics["perplexity"] = perplexity
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@@ -20,6 +20,7 @@ Here is the full list of checkpoints on the hub that can be pretrained by this s
https://huggingface.co/models?filter=t5
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import json
import logging
import os
import sys
@@ -401,7 +402,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)
if __name__ == "__main__":
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
@@ -522,9 +523,7 @@ if __name__ == "__main__":
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
)
elif model_args.model_name_or_path:
config = T5Config.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
)
config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
@@ -617,6 +616,7 @@ if __name__ == "__main__":
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
else:
config.vocab_size = len(tokenizer)
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
# Data collator
@@ -808,3 +808,33 @@ if __name__ == "__main__":
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
# Model forward
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@@ -18,6 +18,7 @@ Fine-tuning the library models for question answering.
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
import json
import logging
import os
import random
@@ -911,6 +912,58 @@ def main():
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# endregion
# Eval after training
if training_args.do_eval:
eval_metrics = {}
all_start_logits = []
all_end_logits = []
eva_loader = eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(eva_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = p_eval_step(state, batch)
start_logits = np.array([pred for pred in chain(*predictions[0])])
end_logits = np.array([pred for pred in chain(*predictions[1])])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)
# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size
# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: np.array(v) for k, v in batch.items()}
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = eval_step(unreplicate(state), batch)
start_logits = np.array([pred for pred in predictions[0]])
end_logits = np.array([pred for pred in predictions[1]])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
# concatenate the numpy array
start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)
# delete the list of numpy arrays
del all_start_logits
del all_end_logits
outputs_numpy = (start_logits_concat, end_logits_concat)
prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
eval_metrics = compute_metrics(prediction)
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@@ -18,6 +18,7 @@ Fine-tuning the library models for summarization.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import json
import logging
import os
import sys
@@ -816,6 +817,13 @@ def main():
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
logger.info(desc)
# save final metrics in json
if jax.process_index() == 0:
rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
path = os.path.join(training_args.output_dir, "test_results.json")
with open(path, "w") as f:
json.dump(rouge_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,270 @@
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import sys
from unittest.mock import patch
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow
SRC_DIRS = [
os.path.join(os.path.dirname(__file__), dirname)
for dirname in [
"text-classification",
"language-modeling",
"summarization",
"token-classification",
"question-answering",
]
]
sys.path.extend(SRC_DIRS)
if SRC_DIRS is not None:
import run_clm_flax
import run_flax_glue
import run_flax_ner
import run_mlm_flax
import run_qa
import run_summarization_flax
import run_t5_mlm_flax
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
def get_setup_file():
parser = argparse.ArgumentParser()
parser.add_argument("-f")
args = parser.parse_args()
return args.f
def get_results(output_dir, split="eval"):
results = {}
path = os.path.join(output_dir, f"{split}_results.json")
if os.path.exists(path):
with open(path, "r") as f:
results = json.load(f)
else:
raise ValueError(f"can't find {path}")
return results
class ExamplesTests(TestCasePlus):
def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue.py
--model_name_or_path distilbert-base-uncased
--output_dir {tmp_dir}
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
--validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--learning_rate=1e-4
--max_train_steps=10
--num_warmup_steps=2
--seed=42
--max_length=128
""".split()
with patch.object(sys, "argv", testargs):
run_flax_glue.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
def test_run_clm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_clm_flax.py
--model_name_or_path distilgpt2
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--do_train
--do_eval
--block_size 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--num_train_epochs 2
--logging_steps 2 --eval_steps 2
--output_dir {tmp_dir}
--overwrite_output_dir
""".split()
with patch.object(sys, "argv", testargs):
run_clm_flax.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_perplexity"], 100)
@slow
def test_run_summarization(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_summarization.py
--model_name_or_path t5-small
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
--test_file tests/fixtures/tests_samples/xsum/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=50
--warmup_steps=8
--do_train
--do_eval
--do_predict
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
""".split()
with patch.object(sys, "argv", testargs):
run_summarization_flax.main()
result = get_results(tmp_dir, split="test")
self.assertGreaterEqual(result["test_rouge1"], 10)
self.assertGreaterEqual(result["test_rouge2"], 2)
self.assertGreaterEqual(result["test_rougeL"], 7)
self.assertGreaterEqual(result["test_rougeLsum"], 7)
def test_run_mlm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_mlm.py
--model_name_or_path distilroberta-base
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--output_dir {tmp_dir}
--overwrite_output_dir
--max_seq_length 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--logging_steps 2 --eval_steps 2
--do_train
--do_eval
--num_train_epochs=1
""".split()
with patch.object(sys, "argv", testargs):
run_mlm_flax.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_perplexity"], 42)
@slow
def test_run_t5_mlm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_t5_mlm_flax.py
--model_name_or_path t5-small
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--num_train_epochs 2
--logging_steps 2 --eval_steps 2
--output_dir {tmp_dir}
--overwrite_output_dir
""".split()
with patch.object(sys, "argv", testargs):
run_t5_mlm_flax.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.42)
def test_run_ner(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_flax_ner.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/conll/sample.json
--validation_file tests/fixtures/tests_samples/conll/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--do_train
--do_eval
--warmup_steps=2
--learning_rate=2e-4
--logging_steps 2 --eval_steps 2
--per_device_train_batch_size=2
--per_device_eval_batch_size=2
--num_train_epochs={epochs}
--seed 7
""".split()
with patch.object(sys, "argv", testargs):
run_flax_ner.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertGreaterEqual(result["eval_f1"], 0.3)
def test_run_qa(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_qa.py
--model_name_or_path bert-base-uncased
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=10
--warmup_steps=2
--do_train
--do_eval
--logging_steps 2 --eval_steps 2
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
""".split()
with patch.object(sys, "argv", testargs):
run_qa.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)

View File

@@ -15,6 +15,7 @@
# limitations under the License.
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
import argparse
import json
import logging
import os
import random
@@ -522,6 +523,13 @@ def main():
if args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
# save the eval metrics in json
if jax.process_index() == 0:
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
path = os.path.join(args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metric, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
import json
import logging
import os
import random
@@ -675,6 +676,42 @@ def main():
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# Eval after training
if training_args.do_eval:
eval_metrics = {}
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
labels = batch.pop("labels")
predictions = p_eval_step(state, batch)
predictions = np.array([pred for pred in chain(*predictions)])
labels = np.array([label for label in chain(*labels)])
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(predictions=preds, references=refs)
# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size
# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: np.array(v) for k, v in batch.items()}
labels = np.array(batch.pop("labels"))
predictions = eval_step(unreplicate(state), batch)
labels[np.array(batch["attention_mask"]) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(predictions=preds, references=refs)
eval_metrics = compute_metrics()
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()