Add Flax example tests (#14599)
* add test for glue * add tests for clm * fix clm test * add summrization tests * more tests * fix few tests * add test for t5 mlm * fix t5 mlm test * fix tests for multi device * cleanup * ci job * fix metric file name * make t5 more robust
This commit is contained in:
@@ -613,6 +613,69 @@ jobs:
|
|||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: ~/transformers/reports
|
path: ~/transformers/reports
|
||||||
|
|
||||||
|
run_examples_flax:
|
||||||
|
working_directory: ~/transformers
|
||||||
|
docker:
|
||||||
|
- image: circleci/python:3.7
|
||||||
|
environment:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
TRANSFORMERS_IS_CI: yes
|
||||||
|
resource_class: xlarge
|
||||||
|
parallelism: 1
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- restore_cache:
|
||||||
|
keys:
|
||||||
|
- v0.4-flax_examples-{{ checksum "setup.py" }}
|
||||||
|
- v0.4-{{ checksum "setup.py" }}
|
||||||
|
- run: pip install --upgrade pip
|
||||||
|
- run: sudo pip install .[flax,testing,sentencepiece]
|
||||||
|
- run: pip install -r examples/flax/_tests_requirements.txt
|
||||||
|
- save_cache:
|
||||||
|
key: v0.4-flax_examples-{{ checksum "setup.py" }}
|
||||||
|
paths:
|
||||||
|
- '~/.cache/pip'
|
||||||
|
- run: python utils/tests_fetcher.py --filters examples tests | tee test_preparation.txt
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/transformers/test_preparation.txt
|
||||||
|
- run: |
|
||||||
|
if [ -f test_list.txt ]; then
|
||||||
|
python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee tests_output.txt
|
||||||
|
fi
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/transformers/flax_examples_output.txt
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/transformers/reports
|
||||||
|
|
||||||
|
run_examples_flax_all:
|
||||||
|
working_directory: ~/transformers
|
||||||
|
docker:
|
||||||
|
- image: circleci/python:3.7
|
||||||
|
environment:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
TRANSFORMERS_IS_CI: yes
|
||||||
|
resource_class: xlarge
|
||||||
|
parallelism: 1
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- restore_cache:
|
||||||
|
keys:
|
||||||
|
- v0.4-flax_examples-{{ checksum "setup.py" }}
|
||||||
|
- v0.4-{{ checksum "setup.py" }}
|
||||||
|
- run: pip install --upgrade pip
|
||||||
|
- run: sudo pip install .[flax,testing,sentencepiece]
|
||||||
|
- run: pip install -r examples/flax/_tests_requirements.txt
|
||||||
|
- save_cache:
|
||||||
|
key: v0.4-flax_examples-{{ checksum "setup.py" }}
|
||||||
|
paths:
|
||||||
|
- '~/.cache/pip'
|
||||||
|
- run: |
|
||||||
|
TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_flax ./examples/flax/ | tee examples_output.txt
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/transformers/flax_examples_output.txt
|
||||||
|
- store_artifacts:
|
||||||
|
path: ~/transformers/reports
|
||||||
|
|
||||||
run_tests_hub:
|
run_tests_hub:
|
||||||
working_directory: ~/transformers
|
working_directory: ~/transformers
|
||||||
docker:
|
docker:
|
||||||
|
|||||||
7
examples/flax/_tests_requirements.txt
Normal file
7
examples/flax/_tests_requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
datasets >= 1.1.3
|
||||||
|
pytest
|
||||||
|
conllu
|
||||||
|
nltk
|
||||||
|
rouge-score
|
||||||
|
seqeval
|
||||||
|
tensorboard
|
||||||
@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=causal-lm
|
|||||||
"""
|
"""
|
||||||
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -672,6 +673,32 @@ def main():
|
|||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
|
|
||||||
|
# Eval after training
|
||||||
|
if training_args.do_eval:
|
||||||
|
eval_metrics = []
|
||||||
|
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
||||||
|
eval_steps = len(eval_dataset) // eval_batch_size
|
||||||
|
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
||||||
|
# Model forward
|
||||||
|
batch = shard(next(eval_loader))
|
||||||
|
metrics = p_eval_step(state.params, batch)
|
||||||
|
eval_metrics.append(metrics)
|
||||||
|
|
||||||
|
# normalize eval metrics
|
||||||
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
|
eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
|
||||||
|
|
||||||
|
try:
|
||||||
|
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
||||||
|
except OverflowError:
|
||||||
|
eval_metrics["perplexity"] = float("inf")
|
||||||
|
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||||
|
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ text file or a dataset.
|
|||||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||||
https://huggingface.co/models?filter=masked-lm
|
https://huggingface.co/models?filter=masked-lm
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
@@ -271,7 +273,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
|
|||||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
# See all possible arguments in src/transformers/training_args.py
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
# or by passing the --help flag to this script.
|
# or by passing the --help flag to this script.
|
||||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
@@ -700,3 +702,41 @@ if __name__ == "__main__":
|
|||||||
tokenizer.save_pretrained(training_args.output_dir)
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
|
|
||||||
|
# Eval after training
|
||||||
|
if training_args.do_eval:
|
||||||
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
|
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||||
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
|
eval_metrics = []
|
||||||
|
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||||
|
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||||
|
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||||
|
|
||||||
|
# Model forward
|
||||||
|
model_inputs = shard(model_inputs.data)
|
||||||
|
metrics = p_eval_step(state.params, model_inputs)
|
||||||
|
eval_metrics.append(metrics)
|
||||||
|
|
||||||
|
# normalize eval metrics
|
||||||
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
|
eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
|
||||||
|
eval_normalizer = eval_metrics.pop("normalizer")
|
||||||
|
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||||
|
|
||||||
|
try:
|
||||||
|
perplexity = math.exp(eval_metrics["loss"])
|
||||||
|
except OverflowError:
|
||||||
|
perplexity = float("inf")
|
||||||
|
eval_metrics["perplexity"] = perplexity
|
||||||
|
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||||
|
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ Here is the full list of checkpoints on the hub that can be pretrained by this s
|
|||||||
https://huggingface.co/models?filter=t5
|
https://huggingface.co/models?filter=t5
|
||||||
"""
|
"""
|
||||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -401,7 +402,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
|
|||||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
# See all possible arguments in src/transformers/training_args.py
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
# or by passing the --help flag to this script.
|
# or by passing the --help flag to this script.
|
||||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
@@ -522,9 +523,7 @@ if __name__ == "__main__":
|
|||||||
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
|
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
|
||||||
)
|
)
|
||||||
elif model_args.model_name_or_path:
|
elif model_args.model_name_or_path:
|
||||||
config = T5Config.from_pretrained(
|
config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
config = CONFIG_MAPPING[model_args.model_type]()
|
config = CONFIG_MAPPING[model_args.model_type]()
|
||||||
logger.warning("You are instantiating a new config instance from scratch.")
|
logger.warning("You are instantiating a new config instance from scratch.")
|
||||||
@@ -617,6 +616,7 @@ if __name__ == "__main__":
|
|||||||
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
config.vocab_size = len(tokenizer)
|
||||||
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
@@ -808,3 +808,33 @@ if __name__ == "__main__":
|
|||||||
tokenizer.save_pretrained(training_args.output_dir)
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
|
|
||||||
|
# Eval after training
|
||||||
|
if training_args.do_eval:
|
||||||
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
|
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||||
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
|
eval_metrics = []
|
||||||
|
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||||
|
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||||
|
model_inputs = data_collator(samples)
|
||||||
|
|
||||||
|
# Model forward
|
||||||
|
model_inputs = shard(model_inputs.data)
|
||||||
|
metrics = p_eval_step(state.params, model_inputs)
|
||||||
|
eval_metrics.append(metrics)
|
||||||
|
|
||||||
|
# get eval metrics
|
||||||
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
|
eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
|
||||||
|
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||||
|
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ Fine-tuning the library models for question answering.
|
|||||||
"""
|
"""
|
||||||
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
|
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -911,6 +912,58 @@ def main():
|
|||||||
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
# Eval after training
|
||||||
|
if training_args.do_eval:
|
||||||
|
eval_metrics = {}
|
||||||
|
all_start_logits = []
|
||||||
|
all_end_logits = []
|
||||||
|
|
||||||
|
eva_loader = eval_data_collator(eval_dataset, eval_batch_size)
|
||||||
|
for batch in tqdm(eva_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
|
||||||
|
_ = batch.pop("example_id")
|
||||||
|
_ = batch.pop("offset_mapping")
|
||||||
|
predictions = p_eval_step(state, batch)
|
||||||
|
start_logits = np.array([pred for pred in chain(*predictions[0])])
|
||||||
|
end_logits = np.array([pred for pred in chain(*predictions[1])])
|
||||||
|
all_start_logits.append(start_logits)
|
||||||
|
all_end_logits.append(end_logits)
|
||||||
|
|
||||||
|
# evaluate also on leftover examples (not divisible by batch_size)
|
||||||
|
num_leftover_samples = len(eval_dataset) % eval_batch_size
|
||||||
|
|
||||||
|
# make sure leftover batch is evaluated on one device
|
||||||
|
if num_leftover_samples > 0 and jax.process_index() == 0:
|
||||||
|
# take leftover samples
|
||||||
|
batch = eval_dataset[-num_leftover_samples:]
|
||||||
|
batch = {k: np.array(v) for k, v in batch.items()}
|
||||||
|
_ = batch.pop("example_id")
|
||||||
|
_ = batch.pop("offset_mapping")
|
||||||
|
|
||||||
|
predictions = eval_step(unreplicate(state), batch)
|
||||||
|
start_logits = np.array([pred for pred in predictions[0]])
|
||||||
|
end_logits = np.array([pred for pred in predictions[1]])
|
||||||
|
all_start_logits.append(start_logits)
|
||||||
|
all_end_logits.append(end_logits)
|
||||||
|
|
||||||
|
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
|
||||||
|
|
||||||
|
# concatenate the numpy array
|
||||||
|
start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
|
||||||
|
end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)
|
||||||
|
|
||||||
|
# delete the list of numpy arrays
|
||||||
|
del all_start_logits
|
||||||
|
del all_end_logits
|
||||||
|
outputs_numpy = (start_logits_concat, end_logits_concat)
|
||||||
|
prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
|
||||||
|
eval_metrics = compute_metrics(prediction)
|
||||||
|
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||||
|
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ Fine-tuning the library models for summarization.
|
|||||||
"""
|
"""
|
||||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -816,6 +817,13 @@ def main():
|
|||||||
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
||||||
logger.info(desc)
|
logger.info(desc)
|
||||||
|
|
||||||
|
# save final metrics in json
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
|
||||||
|
path = os.path.join(training_args.output_dir, "test_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(rouge_metrics, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
270
examples/flax/test_examples.py
Normal file
270
examples/flax/test_examples.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow
|
||||||
|
|
||||||
|
|
||||||
|
SRC_DIRS = [
|
||||||
|
os.path.join(os.path.dirname(__file__), dirname)
|
||||||
|
for dirname in [
|
||||||
|
"text-classification",
|
||||||
|
"language-modeling",
|
||||||
|
"summarization",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
sys.path.extend(SRC_DIRS)
|
||||||
|
|
||||||
|
|
||||||
|
if SRC_DIRS is not None:
|
||||||
|
import run_clm_flax
|
||||||
|
import run_flax_glue
|
||||||
|
import run_flax_ner
|
||||||
|
import run_mlm_flax
|
||||||
|
import run_qa
|
||||||
|
import run_summarization_flax
|
||||||
|
import run_t5_mlm_flax
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_setup_file():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-f")
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args.f
|
||||||
|
|
||||||
|
|
||||||
|
def get_results(output_dir, split="eval"):
|
||||||
|
results = {}
|
||||||
|
path = os.path.join(output_dir, f"{split}_results.json")
|
||||||
|
if os.path.exists(path):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
results = json.load(f)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"can't find {path}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class ExamplesTests(TestCasePlus):
|
||||||
|
def test_run_glue(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_glue.py
|
||||||
|
--model_name_or_path distilbert-base-uncased
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
|
||||||
|
--validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
|
||||||
|
--per_device_train_batch_size=2
|
||||||
|
--per_device_eval_batch_size=1
|
||||||
|
--learning_rate=1e-4
|
||||||
|
--max_train_steps=10
|
||||||
|
--num_warmup_steps=2
|
||||||
|
--seed=42
|
||||||
|
--max_length=128
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_flax_glue.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||||
|
|
||||||
|
def test_run_clm(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_clm_flax.py
|
||||||
|
--model_name_or_path distilgpt2
|
||||||
|
--train_file ./tests/fixtures/sample_text.txt
|
||||||
|
--validation_file ./tests/fixtures/sample_text.txt
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
--block_size 128
|
||||||
|
--per_device_train_batch_size 4
|
||||||
|
--per_device_eval_batch_size 4
|
||||||
|
--num_train_epochs 2
|
||||||
|
--logging_steps 2 --eval_steps 2
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_clm_flax.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertLess(result["eval_perplexity"], 100)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_run_summarization(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_summarization.py
|
||||||
|
--model_name_or_path t5-small
|
||||||
|
--train_file tests/fixtures/tests_samples/xsum/sample.json
|
||||||
|
--validation_file tests/fixtures/tests_samples/xsum/sample.json
|
||||||
|
--test_file tests/fixtures/tests_samples/xsum/sample.json
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
|
--max_steps=50
|
||||||
|
--warmup_steps=8
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
--do_predict
|
||||||
|
--learning_rate=2e-4
|
||||||
|
--per_device_train_batch_size=2
|
||||||
|
--per_device_eval_batch_size=1
|
||||||
|
--predict_with_generate
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_summarization_flax.main()
|
||||||
|
result = get_results(tmp_dir, split="test")
|
||||||
|
self.assertGreaterEqual(result["test_rouge1"], 10)
|
||||||
|
self.assertGreaterEqual(result["test_rouge2"], 2)
|
||||||
|
self.assertGreaterEqual(result["test_rougeL"], 7)
|
||||||
|
self.assertGreaterEqual(result["test_rougeLsum"], 7)
|
||||||
|
|
||||||
|
def test_run_mlm(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_mlm.py
|
||||||
|
--model_name_or_path distilroberta-base
|
||||||
|
--train_file ./tests/fixtures/sample_text.txt
|
||||||
|
--validation_file ./tests/fixtures/sample_text.txt
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
|
--max_seq_length 128
|
||||||
|
--per_device_train_batch_size 4
|
||||||
|
--per_device_eval_batch_size 4
|
||||||
|
--logging_steps 2 --eval_steps 2
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
--num_train_epochs=1
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_mlm_flax.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertLess(result["eval_perplexity"], 42)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_run_t5_mlm(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_t5_mlm_flax.py
|
||||||
|
--model_name_or_path t5-small
|
||||||
|
--train_file ./tests/fixtures/sample_text.txt
|
||||||
|
--validation_file ./tests/fixtures/sample_text.txt
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
--max_seq_length 128
|
||||||
|
--per_device_train_batch_size 4
|
||||||
|
--per_device_eval_batch_size 4
|
||||||
|
--num_train_epochs 2
|
||||||
|
--logging_steps 2 --eval_steps 2
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_t5_mlm_flax.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertGreaterEqual(result["eval_accuracy"], 0.42)
|
||||||
|
|
||||||
|
def test_run_ner(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||||
|
epochs = 7 if get_gpu_count() > 1 else 2
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_flax_ner.py
|
||||||
|
--model_name_or_path bert-base-uncased
|
||||||
|
--train_file tests/fixtures/tests_samples/conll/sample.json
|
||||||
|
--validation_file tests/fixtures/tests_samples/conll/sample.json
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
--warmup_steps=2
|
||||||
|
--learning_rate=2e-4
|
||||||
|
--logging_steps 2 --eval_steps 2
|
||||||
|
--per_device_train_batch_size=2
|
||||||
|
--per_device_eval_batch_size=2
|
||||||
|
--num_train_epochs={epochs}
|
||||||
|
--seed 7
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_flax_ner.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||||
|
self.assertGreaterEqual(result["eval_f1"], 0.3)
|
||||||
|
|
||||||
|
def test_run_qa(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_qa.py
|
||||||
|
--model_name_or_path bert-base-uncased
|
||||||
|
--version_2_with_negative
|
||||||
|
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||||
|
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
|
--max_steps=10
|
||||||
|
--warmup_steps=2
|
||||||
|
--do_train
|
||||||
|
--do_eval
|
||||||
|
--logging_steps 2 --eval_steps 2
|
||||||
|
--learning_rate=2e-4
|
||||||
|
--per_device_train_batch_size=2
|
||||||
|
--per_device_eval_batch_size=1
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_qa.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertGreaterEqual(result["eval_f1"], 30)
|
||||||
|
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
|
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -522,6 +523,13 @@ def main():
|
|||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||||
|
|
||||||
|
# save the eval metrics in json
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
|
||||||
|
path = os.path.join(args.output_dir, "eval_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(eval_metric, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
|
""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -675,6 +676,42 @@ def main():
|
|||||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
||||||
|
|
||||||
|
# Eval after training
|
||||||
|
if training_args.do_eval:
|
||||||
|
eval_metrics = {}
|
||||||
|
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
|
||||||
|
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
|
||||||
|
labels = batch.pop("labels")
|
||||||
|
predictions = p_eval_step(state, batch)
|
||||||
|
predictions = np.array([pred for pred in chain(*predictions)])
|
||||||
|
labels = np.array([label for label in chain(*labels)])
|
||||||
|
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
|
||||||
|
preds, refs = get_labels(predictions, labels)
|
||||||
|
metric.add_batch(predictions=preds, references=refs)
|
||||||
|
|
||||||
|
# evaluate also on leftover examples (not divisible by batch_size)
|
||||||
|
num_leftover_samples = len(eval_dataset) % eval_batch_size
|
||||||
|
|
||||||
|
# make sure leftover batch is evaluated on one device
|
||||||
|
if num_leftover_samples > 0 and jax.process_index() == 0:
|
||||||
|
# take leftover samples
|
||||||
|
batch = eval_dataset[-num_leftover_samples:]
|
||||||
|
batch = {k: np.array(v) for k, v in batch.items()}
|
||||||
|
|
||||||
|
labels = np.array(batch.pop("labels"))
|
||||||
|
predictions = eval_step(unreplicate(state), batch)
|
||||||
|
labels[np.array(batch["attention_mask"]) == 0] = -100
|
||||||
|
preds, refs = get_labels(predictions, labels)
|
||||||
|
metric.add_batch(predictions=preds, references=refs)
|
||||||
|
|
||||||
|
eval_metrics = compute_metrics()
|
||||||
|
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||||
|
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -600,7 +600,7 @@ def require_deepspeed(test_case):
|
|||||||
|
|
||||||
def get_gpu_count():
|
def get_gpu_count():
|
||||||
"""
|
"""
|
||||||
Return the number of available gpus (regardless of whether torch or tf is used)
|
Return the number of available gpus (regardless of whether torch, tf or jax is used)
|
||||||
"""
|
"""
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -610,6 +610,10 @@ def get_gpu_count():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
return len(tf.config.list_physical_devices("GPU"))
|
return len(tf.config.list_physical_devices("GPU"))
|
||||||
|
elif is_flax_available():
|
||||||
|
import jax
|
||||||
|
|
||||||
|
return jax.device_count()
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user