Add Flax example tests (#14599)
* add test for glue * add tests for clm * fix clm test * add summrization tests * more tests * fix few tests * add test for t5 mlm * fix t5 mlm test * fix tests for multi device * cleanup * ci job * fix metric file name * make t5 more robust
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -522,6 +523,13 @@ def main():
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||
|
||||
# save the eval metrics in json
|
||||
if jax.process_index() == 0:
|
||||
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
|
||||
path = os.path.join(args.output_dir, "eval_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(eval_metric, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user