[s2sTrainer] test + code cleanup (#7467)

This commit is contained in:
Sam Shleifer
2020-10-01 00:33:01 -04:00
committed by GitHub
parent 097049b81b
commit 48f23f92a8
5 changed files with 102 additions and 116 deletions

View File

@@ -1,4 +1,3 @@
import json
import logging
import os
import sys
@@ -29,10 +28,13 @@ from utils import (
assert_all_frozen,
calculate_bleu,
calculate_rouge,
freeze_embeds,
freeze_params,
lmap,
save_json,
trim_batch,
use_task_specific_params,
write_txt_file,
)
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
self.data_args = data_args
self.tpu_num_cores = tpu_num_cores
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
labels = labels
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
labels = labels
batch = {
"input_ids": input_ids,
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
return batch
def _shift_right_t5(self, input_ids):
decoder_start_token_id = self.pad_token_id
assert (
decoder_start_token_id is not None
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
shifted_input_ids[..., 0] = self.pad_token_id
return shifted_input_ids
def _encode(self, batch) -> Dict[str, torch.Tensor]:
@@ -267,17 +261,15 @@ def main():
use_task_specific_params(model, data_args.task)
# set num_beams for evaluation
if data_args.eval_beams is not None:
model.config.num_beams = data_args.eval_beams
assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
# set max length for generation
model.config.max_generate_length = data_args.val_max_target_length
if data_args.eval_beams is None:
data_args.eval_beams = model.config.num_beams
# set decoder_start_token_id for MBart
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
model.config.decoder_start_token_id = decoder_start_token_id
assert (
data_args.tgt_lang is not None and data_args.src_lang is not None
), "mBart requires --tgt_lang and --src_lang"
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def non_pad_len(tokens: np.ndarray) -> int:
@@ -293,32 +285,20 @@ def main():
def summarization_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
rouge: Dict = calculate_rouge(pred_str, label_str)
summ_len = np.mean(lmap(non_pad_len, pred.predictions))
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
rouge.update({"gen_len": summ_len})
return rouge
def translation_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
bleu: Dict = calculate_bleu(pred_str, label_str)
gen_len = np.mean(lmap(non_pad_len, pred.predictions))
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
bleu.update({"gen_len": gen_len})
return bleu
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
return compute_metrics_fn
def freeze_embeds(model: torch.nn.Module):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
try:
freeze_params(model.model.shared)
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
except AttributeError:
freeze_params(model.shared)
for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens)
if model_args.freeze_embeds:
freeze_embeds(model)
if model_args.freeze_encoder:
@@ -376,6 +356,7 @@ def main():
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
data_args=data_args,
)
# Training
@@ -396,41 +377,36 @@ def main():
result = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
if trainer.is_world_process_zero():
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
with open(output_eval_file, "w") as f:
json.dump(result, f)
save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
eval_results.update(result)
if training_args.do_predict:
logging.info("*** Test ***")
test_output = trainer.predict(test_dataset=test_dataset)
test_metrics = test_output.metrics
test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
output_test_file = os.path.join(training_args.output_dir, "test_results.json")
test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
if trainer.is_world_process_zero():
logger.info("***** Test results *****")
for key, value in test_metrics.items():
logger.info(" %s = %s", key, value)
with open(output_test_file, "w") as f:
json.dump(test_metrics, f)
save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
eval_results.update(test_metrics)
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
test_preds = tokenizer.batch_decode(
test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
test_preds = lmap(str.strip, test_preds)
output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
with open(output_test_pred_file, "w") as f:
f.write("\n".join(test_preds))
write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
if trainer.is_world_process_zero():
save_json(eval_results, "all_results.json")
return eval_results