Add Flax example tests (#14599)

* add test for glue

* add tests for clm

* fix clm test

* add summrization tests

* more tests

* fix few tests

* add test for t5 mlm

* fix t5 mlm test

* fix tests for multi device

* cleanup

* ci job

* fix metric file name

* make t5 more robust
This commit is contained in:
Suraj Patil
2021-12-06 10:48:58 +05:30
committed by GitHub
parent 803a8cd18f
commit c5bd732ac6
11 changed files with 553 additions and 6 deletions

View File

@@ -15,6 +15,7 @@
# limitations under the License.
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
import argparse
import json
import logging
import os
import random
@@ -522,6 +523,13 @@ def main():
if args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
# save the eval metrics in json
if jax.process_index() == 0:
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
path = os.path.join(args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metric, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()