[s2sTrainer] test + code cleanup (#7467)
This commit is contained in:
@@ -26,6 +26,7 @@ from utils import (
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
if self.hparams.freeze_embeds:
|
||||
self.freeze_embeds()
|
||||
freeze_embeds(self.model)
|
||||
if self.hparams.freeze_encoder:
|
||||
freeze_params(self.model.get_encoder())
|
||||
assert_all_frozen(self.model.get_encoder())
|
||||
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
|
||||
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||
)
|
||||
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
||||
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
|
||||
if self.hparams.eval_max_gen_length is not None:
|
||||
self.eval_max_length = self.hparams.eval_max_gen_length
|
||||
else:
|
||||
self.eval_max_length = self.model.config.max_length
|
||||
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
if self.model_type == "t5":
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
elif self.model_type == "fsmt":
|
||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
else:
|
||||
freeze_params(self.model.model.shared)
|
||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -29,10 +28,13 @@ from utils import (
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
lmap,
|
||||
save_json,
|
||||
trim_batch,
|
||||
use_task_specific_params,
|
||||
write_txt_file,
|
||||
)
|
||||
|
||||
|
||||
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
labels = labels
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||
labels = labels
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
|
||||
return batch
|
||||
|
||||
def _shift_right_t5(self, input_ids):
|
||||
decoder_start_token_id = self.pad_token_id
|
||||
|
||||
assert (
|
||||
decoder_start_token_id is not None
|
||||
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
@@ -267,17 +261,15 @@ def main():
|
||||
use_task_specific_params(model, data_args.task)
|
||||
|
||||
# set num_beams for evaluation
|
||||
if data_args.eval_beams is not None:
|
||||
model.config.num_beams = data_args.eval_beams
|
||||
assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
|
||||
|
||||
# set max length for generation
|
||||
model.config.max_generate_length = data_args.val_max_target_length
|
||||
if data_args.eval_beams is None:
|
||||
data_args.eval_beams = model.config.num_beams
|
||||
|
||||
# set decoder_start_token_id for MBart
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
||||
decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
model.config.decoder_start_token_id = decoder_start_token_id
|
||||
assert (
|
||||
data_args.tgt_lang is not None and data_args.src_lang is not None
|
||||
), "mBart requires --tgt_lang and --src_lang"
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
|
||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
@@ -293,32 +285,20 @@ def main():
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
|
||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||
return compute_metrics_fn
|
||||
|
||||
def freeze_embeds(model: torch.nn.Module):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
try:
|
||||
freeze_params(model.model.shared)
|
||||
for d in [model.model.encoder, model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
except AttributeError:
|
||||
freeze_params(model.shared)
|
||||
for d in [model.encoder, model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
if model_args.freeze_embeds:
|
||||
freeze_embeds(model)
|
||||
if model_args.freeze_encoder:
|
||||
@@ -376,6 +356,7 @@ def main():
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
# Training
|
||||
@@ -396,41 +377,36 @@ def main():
|
||||
|
||||
result = trainer.evaluate()
|
||||
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
with open(output_eval_file, "w") as f:
|
||||
json.dump(result, f)
|
||||
|
||||
save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
|
||||
eval_results.update(result)
|
||||
|
||||
if training_args.do_predict:
|
||||
logging.info("*** Test ***")
|
||||
|
||||
test_output = trainer.predict(test_dataset=test_dataset)
|
||||
test_metrics = test_output.metrics
|
||||
test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
|
||||
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_results.json")
|
||||
test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Test results *****")
|
||||
for key, value in test_metrics.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
with open(output_test_file, "w") as f:
|
||||
json.dump(test_metrics, f)
|
||||
save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
||||
eval_results.update(test_metrics)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
|
||||
test_preds = tokenizer.batch_decode(
|
||||
test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
test_preds = lmap(str.strip, test_preds)
|
||||
output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
||||
with open(output_test_pred_file, "w") as f:
|
||||
f.write("\n".join(test_preds))
|
||||
write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
save_json(eval_results, "all_results.json")
|
||||
return eval_results
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
def __init__(self, data_args, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.data_args = data_args
|
||||
self.max_gen_length = data_args.val_max_target_length
|
||||
self.pad_token_id = self.model.config.pad_token_id
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
return None
|
||||
@@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer):
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
logits = outputs[0]
|
||||
return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
|
||||
return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)
|
||||
|
||||
def _compute_loss(self, logits, labels, ignore_index):
|
||||
if self.args.label_smoothing == 0:
|
||||
@@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer):
|
||||
"""
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
max_length = (
|
||||
model.config.max_generate_length
|
||||
if hasattr(model.config, "max_generate_length")
|
||||
else model.config.max_position_embeddings
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||
generated_tokens = model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
use_cache=True,
|
||||
num_beams=model.config.num_beams,
|
||||
max_length=max_length,
|
||||
num_beams=self.data_args.eval_beams,
|
||||
max_length=self.max_gen_length,
|
||||
)
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
generated_tokens = self._pad_tensors_to_max_len(
|
||||
generated_tokens, max_length, model.config.pad_token_id
|
||||
generated_tokens, self.max_gen_length, self.pad_token_id
|
||||
)
|
||||
|
||||
labels_out = inputs.get("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs[1]
|
||||
loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
|
||||
# Call forward again to get loss # TODO: avoidable?
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
|
||||
loss = loss.mean().item()
|
||||
if self.args.prediction_loss_only:
|
||||
logits = None
|
||||
else:
|
||||
logits = generated_tokens if self.args.predict_with_generate else logits
|
||||
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
||||
|
||||
labels_out = labels_out.detach()
|
||||
labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
|
||||
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
|
||||
return (loss, logits.detach(), labels)
|
||||
|
||||
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
||||
|
||||
@@ -3,36 +3,54 @@ import sys
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import main
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
MODEL_NAME = MBART_TINY
|
||||
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
||||
set_seed(42)
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
|
||||
@slow
|
||||
def test_model_download():
|
||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer():
|
||||
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer_slow():
|
||||
# TODO(SS): This will fail on devices with more than 1 GPU.
|
||||
# There is a missing call to __init__process_group somewhere
|
||||
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||
|
||||
# Check metrics
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
|
||||
|
||||
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
||||
max_len = "128"
|
||||
num_train_epochs = 4
|
||||
eval_steps = 2
|
||||
output_dir = tempfile.mkdtemp(prefix="test_output")
|
||||
argv = [
|
||||
"--model_name_or_path",
|
||||
MARIAN_MODEL,
|
||||
model_name,
|
||||
"--data_dir",
|
||||
data_dir,
|
||||
"--output_dir",
|
||||
@@ -72,25 +90,17 @@ def test_finetune_trainer():
|
||||
"--sortish_sampler",
|
||||
"--label_smoothing",
|
||||
"0.1",
|
||||
# "--eval_beams",
|
||||
# "2",
|
||||
"--task",
|
||||
"translation",
|
||||
"--tgt_lang",
|
||||
"ro_RO",
|
||||
"--src_lang",
|
||||
"en_XX",
|
||||
]
|
||||
|
||||
testargs = ["finetune_trainer.py"] + argv
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
# Check metrics
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
return output_dir
|
||||
|
||||
@@ -441,6 +441,25 @@ def freeze_params(model: nn.Module):
|
||||
par.requires_grad = False
|
||||
|
||||
|
||||
def freeze_embeds(model):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
model_type = model.config.model_type
|
||||
|
||||
if model_type == "t5":
|
||||
freeze_params(model.shared)
|
||||
for d in [model.encoder, model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
elif model_type == "fsmt":
|
||||
for d in [model.model.encoder, model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
else:
|
||||
freeze_params(model.model.shared)
|
||||
for d in [model.model.encoder, model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
|
||||
def grad_status(model: nn.Module) -> Iterable:
|
||||
return (par.requires_grad for par in model.parameters())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user