[s2sTrainer] test + code cleanup (#7467)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -29,10 +28,13 @@ from utils import (
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
lmap,
|
||||
save_json,
|
||||
trim_batch,
|
||||
use_task_specific_params,
|
||||
write_txt_file,
|
||||
)
|
||||
|
||||
|
||||
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
labels = labels
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||
labels = labels
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
|
||||
return batch
|
||||
|
||||
def _shift_right_t5(self, input_ids):
|
||||
decoder_start_token_id = self.pad_token_id
|
||||
|
||||
assert (
|
||||
decoder_start_token_id is not None
|
||||
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
@@ -267,17 +261,15 @@ def main():
|
||||
use_task_specific_params(model, data_args.task)
|
||||
|
||||
# set num_beams for evaluation
|
||||
if data_args.eval_beams is not None:
|
||||
model.config.num_beams = data_args.eval_beams
|
||||
assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
|
||||
|
||||
# set max length for generation
|
||||
model.config.max_generate_length = data_args.val_max_target_length
|
||||
if data_args.eval_beams is None:
|
||||
data_args.eval_beams = model.config.num_beams
|
||||
|
||||
# set decoder_start_token_id for MBart
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
||||
decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
model.config.decoder_start_token_id = decoder_start_token_id
|
||||
assert (
|
||||
data_args.tgt_lang is not None and data_args.src_lang is not None
|
||||
), "mBart requires --tgt_lang and --src_lang"
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
|
||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
@@ -293,32 +285,20 @@ def main():
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
|
||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||
return compute_metrics_fn
|
||||
|
||||
def freeze_embeds(model: torch.nn.Module):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
try:
|
||||
freeze_params(model.model.shared)
|
||||
for d in [model.model.encoder, model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
except AttributeError:
|
||||
freeze_params(model.shared)
|
||||
for d in [model.encoder, model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
if model_args.freeze_embeds:
|
||||
freeze_embeds(model)
|
||||
if model_args.freeze_encoder:
|
||||
@@ -376,6 +356,7 @@ def main():
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
# Training
|
||||
@@ -396,41 +377,36 @@ def main():
|
||||
|
||||
result = trainer.evaluate()
|
||||
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
with open(output_eval_file, "w") as f:
|
||||
json.dump(result, f)
|
||||
|
||||
save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
|
||||
eval_results.update(result)
|
||||
|
||||
if training_args.do_predict:
|
||||
logging.info("*** Test ***")
|
||||
|
||||
test_output = trainer.predict(test_dataset=test_dataset)
|
||||
test_metrics = test_output.metrics
|
||||
test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
|
||||
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_results.json")
|
||||
test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Test results *****")
|
||||
for key, value in test_metrics.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
with open(output_test_file, "w") as f:
|
||||
json.dump(test_metrics, f)
|
||||
save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
||||
eval_results.update(test_metrics)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
|
||||
test_preds = tokenizer.batch_decode(
|
||||
test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
test_preds = lmap(str.strip, test_preds)
|
||||
output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
||||
with open(output_test_pred_file, "w") as f:
|
||||
f.write("\n".join(test_preds))
|
||||
write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
save_json(eval_results, "all_results.json")
|
||||
return eval_results
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user