Add Flax example tests (#14599)
* add test for glue * add tests for clm * fix clm test * add summrization tests * more tests * fix few tests * add test for t5 mlm * fix t5 mlm test * fix tests for multi device * cleanup * ci job * fix metric file name * make t5 more robust
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -675,6 +676,42 @@ def main():
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
||||
|
||||
# Eval after training
|
||||
if training_args.do_eval:
|
||||
eval_metrics = {}
|
||||
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
|
||||
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
|
||||
labels = batch.pop("labels")
|
||||
predictions = p_eval_step(state, batch)
|
||||
predictions = np.array([pred for pred in chain(*predictions)])
|
||||
labels = np.array([label for label in chain(*labels)])
|
||||
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
|
||||
preds, refs = get_labels(predictions, labels)
|
||||
metric.add_batch(predictions=preds, references=refs)
|
||||
|
||||
# evaluate also on leftover examples (not divisible by batch_size)
|
||||
num_leftover_samples = len(eval_dataset) % eval_batch_size
|
||||
|
||||
# make sure leftover batch is evaluated on one device
|
||||
if num_leftover_samples > 0 and jax.process_index() == 0:
|
||||
# take leftover samples
|
||||
batch = eval_dataset[-num_leftover_samples:]
|
||||
batch = {k: np.array(v) for k, v in batch.items()}
|
||||
|
||||
labels = np.array(batch.pop("labels"))
|
||||
predictions = eval_step(unreplicate(state), batch)
|
||||
labels[np.array(batch["attention_mask"]) == 0] = -100
|
||||
preds, refs = get_labels(predictions, labels)
|
||||
metric.add_batch(predictions=preds, references=refs)
|
||||
|
||||
eval_metrics = compute_metrics()
|
||||
|
||||
if jax.process_index() == 0:
|
||||
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user