Add Flax example tests (#14599)
* add test for glue * add tests for clm * fix clm test * add summrization tests * more tests * fix few tests * add test for t5 mlm * fix t5 mlm test * fix tests for multi device * cleanup * ci job * fix metric file name * make t5 more robust
This commit is contained in:
@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=causal-lm
|
||||
"""
|
||||
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -672,6 +673,32 @@ def main():
|
||||
if training_args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||
|
||||
# Eval after training
|
||||
if training_args.do_eval:
|
||||
eval_metrics = []
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
||||
eval_steps = len(eval_dataset) // eval_batch_size
|
||||
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
||||
# Model forward
|
||||
batch = shard(next(eval_loader))
|
||||
metrics = p_eval_step(state.params, batch)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
|
||||
|
||||
try:
|
||||
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
||||
except OverflowError:
|
||||
eval_metrics["perplexity"] = float("inf")
|
||||
|
||||
if jax.process_index() == 0:
|
||||
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -20,7 +20,9 @@ text file or a dataset.
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=masked-lm
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -271,7 +273,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
|
||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
@@ -700,3 +702,41 @@ if __name__ == "__main__":
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||
|
||||
# Eval after training
|
||||
if training_args.do_eval:
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
|
||||
eval_normalizer = eval_metrics.pop("normalizer")
|
||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
|
||||
try:
|
||||
perplexity = math.exp(eval_metrics["loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
eval_metrics["perplexity"] = perplexity
|
||||
|
||||
if jax.process_index() == 0:
|
||||
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -20,6 +20,7 @@ Here is the full list of checkpoints on the hub that can be pretrained by this s
|
||||
https://huggingface.co/models?filter=t5
|
||||
"""
|
||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -401,7 +402,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
|
||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
@@ -522,9 +523,7 @@ if __name__ == "__main__":
|
||||
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
|
||||
)
|
||||
elif model_args.model_name_or_path:
|
||||
config = T5Config.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
|
||||
)
|
||||
config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
@@ -617,6 +616,7 @@ if __name__ == "__main__":
|
||||
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
else:
|
||||
config.vocab_size = len(tokenizer)
|
||||
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
|
||||
# Data collator
|
||||
@@ -808,3 +808,33 @@ if __name__ == "__main__":
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||
|
||||
# Eval after training
|
||||
if training_args.do_eval:
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# get eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
|
||||
|
||||
if jax.process_index() == 0:
|
||||
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
||||
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user