From c5bd732ac6d262a0825751634cafbc08971842b0 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 6 Dec 2021 10:48:58 +0530 Subject: [PATCH] Add Flax example tests (#14599) * add test for glue * add tests for clm * fix clm test * add summarization tests * more tests * fix few tests * add test for t5 mlm * fix t5 mlm test * fix tests for multi device * cleanup * ci job * fix metric file name * make t5 more robust --- .circleci/config.yml | 63 ++++ examples/flax/_tests_requirements.txt | 7 + .../flax/language-modeling/run_clm_flax.py | 27 ++ .../flax/language-modeling/run_mlm_flax.py | 42 ++- .../flax/language-modeling/run_t5_mlm_flax.py | 38 ++- examples/flax/question-answering/run_qa.py | 53 ++++ .../summarization/run_summarization_flax.py | 8 + examples/flax/test_examples.py | 270 ++++++++++++++++++ .../flax/text-classification/run_flax_glue.py | 8 + .../flax/token-classification/run_flax_ner.py | 37 +++ src/transformers/testing_utils.py | 6 +- 11 files changed, 553 insertions(+), 6 deletions(-) create mode 100644 examples/flax/_tests_requirements.txt create mode 100644 examples/flax/test_examples.py diff --git a/.circleci/config.yml b/.circleci/config.yml index e099814ea6..1b98723499 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -613,6 +613,69 @@ jobs: - store_artifacts: path: ~/transformers/reports + run_examples_flax: + working_directory: ~/transformers + docker: + - image: circleci/python:3.7 + environment: + OMP_NUM_THREADS: 1 + TRANSFORMERS_IS_CI: yes + resource_class: xlarge + parallelism: 1 + steps: + - checkout + - restore_cache: + keys: + - v0.4-flax_examples-{{ checksum "setup.py" }} + - v0.4-{{ checksum "setup.py" }} + - run: pip install --upgrade pip + - run: sudo pip install .[flax,testing,sentencepiece] + - run: pip install -r examples/flax/_tests_requirements.txt + - save_cache: + key: v0.4-flax_examples-{{ checksum "setup.py" }} + paths: + - '~/.cache/pip' + - run: python utils/tests_fetcher.py --filters examples tests | tee 
test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee tests_output.txt + fi + - store_artifacts: + path: ~/transformers/flax_examples_output.txt + - store_artifacts: + path: ~/transformers/reports + + run_examples_flax_all: + working_directory: ~/transformers + docker: + - image: circleci/python:3.7 + environment: + OMP_NUM_THREADS: 1 + TRANSFORMERS_IS_CI: yes + resource_class: xlarge + parallelism: 1 + steps: + - checkout + - restore_cache: + keys: + - v0.4-flax_examples-{{ checksum "setup.py" }} + - v0.4-{{ checksum "setup.py" }} + - run: pip install --upgrade pip + - run: sudo pip install .[flax,testing,sentencepiece] + - run: pip install -r examples/flax/_tests_requirements.txt + - save_cache: + key: v0.4-flax_examples-{{ checksum "setup.py" }} + paths: + - '~/.cache/pip' + - run: | + TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee examples_output.txt + - store_artifacts: + path: ~/transformers/flax_examples_output.txt + - store_artifacts: + path: ~/transformers/reports + run_tests_hub: working_directory: ~/transformers docker: diff --git a/examples/flax/_tests_requirements.txt b/examples/flax/_tests_requirements.txt new file mode 100644 index 0000000000..f9de455f62 --- /dev/null +++ b/examples/flax/_tests_requirements.txt @@ -0,0 +1,7 @@ +datasets >= 1.1.3 +pytest +conllu +nltk +rouge-score +seqeval +tensorboard \ No newline at end of file diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 50054a6044..d9a472c910 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -21,6 +21,7 @@ https://huggingface.co/models?filter=causal-lm """ # You can also adapt this script on your own causal language modeling 
task. Pointers for this are left as comments. +import json import logging import math import os @@ -672,6 +673,32 @@ def main(): if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) + # Eval after training + if training_args.do_eval: + eval_metrics = [] + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) + eval_steps = len(eval_dataset) // eval_batch_size + for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + # Model forward + batch = shard(next(eval_loader)) + metrics = p_eval_step(state.params, batch) + eval_metrics.append(metrics) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics) + + try: + eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) + except OverflowError: + eval_metrics["perplexity"] = float("inf") + + if jax.process_index() == 0: + eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} + path = os.path.join(training_args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metrics, f, indent=4, sort_keys=True) + if __name__ == "__main__": main() diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 3be4bf387d..d10289c898 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -20,7 +20,9 @@ text file or a dataset. 
Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=masked-lm """ +import json import logging +import math import os import sys import time @@ -271,7 +273,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): summary_writer.scalar(f"eval_{metric_name}", value, step) -if __name__ == "__main__": +def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. @@ -700,3 +702,41 @@ if __name__ == "__main__": tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) + + # Eval after training + if training_args.do_eval: + num_eval_samples = len(tokenized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + + eval_metrics = [] + for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples, pad_to_multiple_of=16) + + # Model forward + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) + eval_metrics.append(metrics) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics) + eval_normalizer = eval_metrics.pop("normalizer") + eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) + + try: + perplexity = math.exp(eval_metrics["loss"]) + except OverflowError: + perplexity = float("inf") + eval_metrics["perplexity"] = perplexity + + if jax.process_index() == 0: + eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} 
+ path = os.path.join(training_args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metrics, f, indent=4, sort_keys=True) + + +if __name__ == "__main__": + main() diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index b62a144449..b6b243f5cb 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -20,6 +20,7 @@ Here is the full list of checkpoints on the hub that can be pretrained by this s https://huggingface.co/models?filter=t5 """ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. +import json import logging import os import sys @@ -401,7 +402,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): summary_writer.scalar(f"eval_{metric_name}", value, step) -if __name__ == "__main__": +def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. 
@@ -522,9 +523,7 @@ if __name__ == "__main__": model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer) ) elif model_args.model_name_or_path: - config = T5Config.from_pretrained( - model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer) - ) + config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") @@ -617,6 +616,7 @@ if __name__ == "__main__": model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) ) else: + config.vocab_size = len(tokenizer) model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) # Data collator @@ -808,3 +808,33 @@ if __name__ == "__main__": tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) + + # Eval after training + if training_args.do_eval: + num_eval_samples = len(tokenized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + + eval_metrics = [] + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples) + + # Model forward + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) + eval_metrics.append(metrics) + + # get eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics) + + if jax.process_index() == 0: + eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} + path = 
os.path.join(training_args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metrics, f, indent=4, sort_keys=True) + + +if __name__ == "__main__": + main() diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index df6e37f4a1..38d8229966 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -18,6 +18,7 @@ Fine-tuning the library models for question answering. """ # You can also adapt this script on your own question answering task. Pointers for this are left as comments. +import json import logging import os import random @@ -911,6 +912,58 @@ def main(): epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" # endregion + # Eval after training + if training_args.do_eval: + eval_metrics = {} + all_start_logits = [] + all_end_logits = [] + + eva_loader = eval_data_collator(eval_dataset, eval_batch_size) + for batch in tqdm(eva_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2): + _ = batch.pop("example_id") + _ = batch.pop("offset_mapping") + predictions = p_eval_step(state, batch) + start_logits = np.array([pred for pred in chain(*predictions[0])]) + end_logits = np.array([pred for pred in chain(*predictions[1])]) + all_start_logits.append(start_logits) + all_end_logits.append(end_logits) + + # evaluate also on leftover examples (not divisible by batch_size) + num_leftover_samples = len(eval_dataset) % eval_batch_size + + # make sure leftover batch is evaluated on one device + if num_leftover_samples > 0 and jax.process_index() == 0: + # take leftover samples + batch = eval_dataset[-num_leftover_samples:] + batch = {k: np.array(v) for k, v in batch.items()} + _ = batch.pop("example_id") + _ = batch.pop("offset_mapping") + + predictions = eval_step(unreplicate(state), batch) + start_logits = np.array([pred for pred in predictions[0]]) + end_logits = np.array([pred for pred in predictions[1]]) + 
all_start_logits.append(start_logits) + all_end_logits.append(end_logits) + + max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor + + # concatenate the numpy array + start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len) + end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len) + + # delete the list of numpy arrays + del all_start_logits + del all_end_logits + outputs_numpy = (start_logits_concat, end_logits_concat) + prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy) + eval_metrics = compute_metrics(prediction) + + if jax.process_index() == 0: + eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} + path = os.path.join(training_args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metrics, f, indent=4, sort_keys=True) + if __name__ == "__main__": main() diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 6b0f3becda..25bde389a8 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -18,6 +18,7 @@ Fine-tuning the library models for summarization. """ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
+import json import logging import os import sys @@ -816,6 +817,13 @@ def main(): desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})" logger.info(desc) + # save final metrics in json + if jax.process_index() == 0: + rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()} + path = os.path.join(training_args.output_dir, "test_results.json") + with open(path, "w") as f: + json.dump(rouge_metrics, f, indent=4, sort_keys=True) + if __name__ == "__main__": main() diff --git a/examples/flax/test_examples.py b/examples/flax/test_examples.py new file mode 100644 index 0000000000..f46e3ac75f --- /dev/null +++ b/examples/flax/test_examples.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import argparse +import json +import logging +import os +import sys +from unittest.mock import patch + +from transformers.testing_utils import TestCasePlus, get_gpu_count, slow + + +SRC_DIRS = [ + os.path.join(os.path.dirname(__file__), dirname) + for dirname in [ + "text-classification", + "language-modeling", + "summarization", + "token-classification", + "question-answering", + ] +] +sys.path.extend(SRC_DIRS) + + +if SRC_DIRS is not None: + import run_clm_flax + import run_flax_glue + import run_flax_ner + import run_mlm_flax + import run_qa + import run_summarization_flax + import run_t5_mlm_flax + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() + + +def get_setup_file(): + parser = argparse.ArgumentParser() + parser.add_argument("-f") + args = parser.parse_args() + return args.f + + +def get_results(output_dir, split="eval"): + results = {} + path = os.path.join(output_dir, f"{split}_results.json") + if os.path.exists(path): + with open(path, "r") as f: + results = json.load(f) + else: + raise ValueError(f"can't find {path}") + return results + + +class ExamplesTests(TestCasePlus): + def test_run_glue(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_glue.py + --model_name_or_path distilbert-base-uncased + --output_dir {tmp_dir} + --train_file ./tests/fixtures/tests_samples/MRPC/train.csv + --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv + --per_device_train_batch_size=2 + --per_device_eval_batch_size=1 + --learning_rate=1e-4 + --max_train_steps=10 + --num_warmup_steps=2 + --seed=42 + --max_length=128 + """.split() + + with patch.object(sys, "argv", testargs): + run_flax_glue.main() + result = get_results(tmp_dir) + self.assertGreaterEqual(result["eval_accuracy"], 0.75) + + def test_run_clm(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = 
self.get_auto_remove_tmp_dir() + testargs = f""" + run_clm_flax.py + --model_name_or_path distilgpt2 + --train_file ./tests/fixtures/sample_text.txt + --validation_file ./tests/fixtures/sample_text.txt + --do_train + --do_eval + --block_size 128 + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 + --num_train_epochs 2 + --logging_steps 2 --eval_steps 2 + --output_dir {tmp_dir} + --overwrite_output_dir + """.split() + + with patch.object(sys, "argv", testargs): + run_clm_flax.main() + result = get_results(tmp_dir) + self.assertLess(result["eval_perplexity"], 100) + + @slow + def test_run_summarization(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_summarization.py + --model_name_or_path t5-small + --train_file tests/fixtures/tests_samples/xsum/sample.json + --validation_file tests/fixtures/tests_samples/xsum/sample.json + --test_file tests/fixtures/tests_samples/xsum/sample.json + --output_dir {tmp_dir} + --overwrite_output_dir + --max_steps=50 + --warmup_steps=8 + --do_train + --do_eval + --do_predict + --learning_rate=2e-4 + --per_device_train_batch_size=2 + --per_device_eval_batch_size=1 + --predict_with_generate + """.split() + + with patch.object(sys, "argv", testargs): + run_summarization_flax.main() + result = get_results(tmp_dir, split="test") + self.assertGreaterEqual(result["test_rouge1"], 10) + self.assertGreaterEqual(result["test_rouge2"], 2) + self.assertGreaterEqual(result["test_rougeL"], 7) + self.assertGreaterEqual(result["test_rougeLsum"], 7) + + def test_run_mlm(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_mlm.py + --model_name_or_path distilroberta-base + --train_file ./tests/fixtures/sample_text.txt + --validation_file ./tests/fixtures/sample_text.txt + --output_dir {tmp_dir} + 
--overwrite_output_dir + --max_seq_length 128 + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 + --logging_steps 2 --eval_steps 2 + --do_train + --do_eval + --num_train_epochs=1 + """.split() + + with patch.object(sys, "argv", testargs): + run_mlm_flax.main() + result = get_results(tmp_dir) + self.assertLess(result["eval_perplexity"], 42) + + @slow + def test_run_t5_mlm(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_t5_mlm_flax.py + --model_name_or_path t5-small + --train_file ./tests/fixtures/sample_text.txt + --validation_file ./tests/fixtures/sample_text.txt + --do_train + --do_eval + --max_seq_length 128 + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 + --num_train_epochs 2 + --logging_steps 2 --eval_steps 2 + --output_dir {tmp_dir} + --overwrite_output_dir + """.split() + + with patch.object(sys, "argv", testargs): + run_t5_mlm_flax.main() + result = get_results(tmp_dir) + self.assertGreaterEqual(result["eval_accuracy"], 0.42) + + def test_run_ner(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu + epochs = 7 if get_gpu_count() > 1 else 2 + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_flax_ner.py + --model_name_or_path bert-base-uncased + --train_file tests/fixtures/tests_samples/conll/sample.json + --validation_file tests/fixtures/tests_samples/conll/sample.json + --output_dir {tmp_dir} + --overwrite_output_dir + --do_train + --do_eval + --warmup_steps=2 + --learning_rate=2e-4 + --logging_steps 2 --eval_steps 2 + --per_device_train_batch_size=2 + --per_device_eval_batch_size=2 + --num_train_epochs={epochs} + --seed 7 + """.split() + + with patch.object(sys, "argv", testargs): + run_flax_ner.main() + result = get_results(tmp_dir) + 
self.assertGreaterEqual(result["eval_accuracy"], 0.75) + self.assertGreaterEqual(result["eval_f1"], 0.3) + + def test_run_qa(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_qa.py + --model_name_or_path bert-base-uncased + --version_2_with_negative + --train_file tests/fixtures/tests_samples/SQUAD/sample.json + --validation_file tests/fixtures/tests_samples/SQUAD/sample.json + --output_dir {tmp_dir} + --overwrite_output_dir + --max_steps=10 + --warmup_steps=2 + --do_train + --do_eval + --logging_steps 2 --eval_steps 2 + --learning_rate=2e-4 + --per_device_train_batch_size=2 + --per_device_eval_batch_size=1 + """.split() + + with patch.object(sys, "argv", testargs): + run_qa.main() + result = get_results(tmp_dir) + self.assertGreaterEqual(result["eval_f1"], 30) + self.assertGreaterEqual(result["eval_exact"], 30) diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index ccccfbea96..9044331db5 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -15,6 +15,7 @@ # limitations under the License. 
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE.""" import argparse +import json import logging import os import random @@ -522,6 +523,13 @@ def main(): if args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) + # save the eval metrics in json + if jax.process_index() == 0: + eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()} + path = os.path.join(args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metric, f, indent=4, sort_keys=True) + if __name__ == "__main__": main() diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index ceb2b71624..3a49ff1fc3 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)""" +import json import logging import os import random @@ -675,6 +676,42 @@ def main(): repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) epochs.desc = f"Epoch ... 
{epoch + 1}/{num_epochs}" + # Eval after training + if training_args.do_eval: + eval_metrics = {} + eval_loader = eval_data_collator(eval_dataset, eval_batch_size) + for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2): + labels = batch.pop("labels") + predictions = p_eval_step(state, batch) + predictions = np.array([pred for pred in chain(*predictions)]) + labels = np.array([label for label in chain(*labels)]) + labels[np.array(chain(*batch["attention_mask"])) == 0] = -100 + preds, refs = get_labels(predictions, labels) + metric.add_batch(predictions=preds, references=refs) + + # evaluate also on leftover examples (not divisible by batch_size) + num_leftover_samples = len(eval_dataset) % eval_batch_size + + # make sure leftover batch is evaluated on one device + if num_leftover_samples > 0 and jax.process_index() == 0: + # take leftover samples + batch = eval_dataset[-num_leftover_samples:] + batch = {k: np.array(v) for k, v in batch.items()} + + labels = np.array(batch.pop("labels")) + predictions = eval_step(unreplicate(state), batch) + labels[np.array(batch["attention_mask"]) == 0] = -100 + preds, refs = get_labels(predictions, labels) + metric.add_batch(predictions=preds, references=refs) + + eval_metrics = compute_metrics() + + if jax.process_index() == 0: + eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} + path = os.path.join(training_args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metrics, f, indent=4, sort_keys=True) + if __name__ == "__main__": main() diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index c9032530ca..b0ba7bb284 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -600,7 +600,7 @@ def require_deepspeed(test_case): def get_gpu_count(): """ - Return the number of available gpus (regardless of whether torch or tf is used) + Return the 
number of available gpus (regardless of whether torch, tf or jax is used) """ if is_torch_available(): import torch @@ -610,6 +610,10 @@ def get_gpu_count(): import tensorflow as tf return len(tf.config.list_physical_devices("GPU")) + elif is_flax_available(): + import jax + + return jax.device_count() else: return 0