Add Flax example tests (#14599)
* add test for glue * add tests for clm * fix clm test * add summrization tests * more tests * fix few tests * add test for t5 mlm * fix t5 mlm test * fix tests for multi device * cleanup * ci job * fix metric file name * make t5 more robust
This commit is contained in:
@@ -18,6 +18,7 @@ Fine-tuning the library models for summarization.
|
||||
"""
|
||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -816,6 +817,13 @@ def main():
|
||||
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
||||
logger.info(desc)
|
||||
|
||||
# save final metrics in json
|
||||
if jax.process_index() == 0:
|
||||
rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
|
||||
path = os.path.join(training_args.output_dir, "test_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(rouge_metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user