[s2sTrainer] test + code cleanup (#7467)
This commit is contained in:
@@ -26,6 +26,7 @@ from utils import (
|
|||||||
calculate_bleu,
|
calculate_bleu,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
|
freeze_embeds,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
get_git_info,
|
get_git_info,
|
||||||
label_smoothed_nll_loss,
|
label_smoothed_nll_loss,
|
||||||
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||||
if self.hparams.freeze_embeds:
|
if self.hparams.freeze_embeds:
|
||||||
self.freeze_embeds()
|
freeze_embeds(self.model)
|
||||||
if self.hparams.freeze_encoder:
|
if self.hparams.freeze_encoder:
|
||||||
freeze_params(self.model.get_encoder())
|
freeze_params(self.model.get_encoder())
|
||||||
assert_all_frozen(self.model.get_encoder())
|
assert_all_frozen(self.model.get_encoder())
|
||||||
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
|
|||||||
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||||
)
|
)
|
||||||
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
||||||
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
|
|
||||||
if self.hparams.eval_max_gen_length is not None:
|
if self.hparams.eval_max_gen_length is not None:
|
||||||
self.eval_max_length = self.hparams.eval_max_gen_length
|
self.eval_max_length = self.hparams.eval_max_gen_length
|
||||||
else:
|
else:
|
||||||
self.eval_max_length = self.model.config.max_length
|
self.eval_max_length = self.model.config.max_length
|
||||||
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
||||||
|
|
||||||
def freeze_embeds(self):
|
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
|
||||||
if self.model_type == "t5":
|
|
||||||
freeze_params(self.model.shared)
|
|
||||||
for d in [self.model.encoder, self.model.decoder]:
|
|
||||||
freeze_params(d.embed_tokens)
|
|
||||||
elif self.model_type == "fsmt":
|
|
||||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
|
||||||
freeze_params(d.embed_positions)
|
|
||||||
freeze_params(d.embed_tokens)
|
|
||||||
else:
|
|
||||||
freeze_params(self.model.model.shared)
|
|
||||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
|
||||||
freeze_params(d.embed_positions)
|
|
||||||
freeze_params(d.embed_tokens)
|
|
||||||
|
|
||||||
def forward(self, input_ids, **kwargs):
|
def forward(self, input_ids, **kwargs):
|
||||||
return self.model(input_ids, **kwargs)
|
return self.model(input_ids, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -29,10 +28,13 @@ from utils import (
|
|||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu,
|
calculate_bleu,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
|
freeze_embeds,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
lmap,
|
lmap,
|
||||||
|
save_json,
|
||||||
trim_batch,
|
trim_batch,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
|
write_txt_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
|
|||||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.pad_token_id = tokenizer.pad_token_id
|
self.pad_token_id = tokenizer.pad_token_id
|
||||||
|
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
|
||||||
self.data_args = data_args
|
self.data_args = data_args
|
||||||
self.tpu_num_cores = tpu_num_cores
|
self.tpu_num_cores = tpu_num_cores
|
||||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||||
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
|
|||||||
|
|
||||||
if isinstance(self.tokenizer, T5Tokenizer):
|
if isinstance(self.tokenizer, T5Tokenizer):
|
||||||
decoder_input_ids = self._shift_right_t5(labels)
|
decoder_input_ids = self._shift_right_t5(labels)
|
||||||
labels = labels
|
|
||||||
else:
|
else:
|
||||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||||
labels = labels
|
|
||||||
|
|
||||||
batch = {
|
batch = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
def _shift_right_t5(self, input_ids):
|
def _shift_right_t5(self, input_ids):
|
||||||
decoder_start_token_id = self.pad_token_id
|
|
||||||
|
|
||||||
assert (
|
|
||||||
decoder_start_token_id is not None
|
|
||||||
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
|
||||||
|
|
||||||
# shift inputs to the right
|
# shift inputs to the right
|
||||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||||
shifted_input_ids[..., 0] = decoder_start_token_id
|
shifted_input_ids[..., 0] = self.pad_token_id
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
@@ -267,17 +261,15 @@ def main():
|
|||||||
use_task_specific_params(model, data_args.task)
|
use_task_specific_params(model, data_args.task)
|
||||||
|
|
||||||
# set num_beams for evaluation
|
# set num_beams for evaluation
|
||||||
if data_args.eval_beams is not None:
|
if data_args.eval_beams is None:
|
||||||
model.config.num_beams = data_args.eval_beams
|
data_args.eval_beams = model.config.num_beams
|
||||||
assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
|
|
||||||
|
|
||||||
# set max length for generation
|
|
||||||
model.config.max_generate_length = data_args.val_max_target_length
|
|
||||||
|
|
||||||
# set decoder_start_token_id for MBart
|
# set decoder_start_token_id for MBart
|
||||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
||||||
decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
assert (
|
||||||
model.config.decoder_start_token_id = decoder_start_token_id
|
data_args.tgt_lang is not None and data_args.src_lang is not None
|
||||||
|
), "mBart requires --tgt_lang and --src_lang"
|
||||||
|
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||||
|
|
||||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||||
def non_pad_len(tokens: np.ndarray) -> int:
|
def non_pad_len(tokens: np.ndarray) -> int:
|
||||||
@@ -293,32 +285,20 @@ def main():
|
|||||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||||
pred_str, label_str = decode_pred(pred)
|
pred_str, label_str = decode_pred(pred)
|
||||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||||
summ_len = np.mean(lmap(non_pad_len, pred.predictions))
|
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||||
rouge.update({"gen_len": summ_len})
|
rouge.update({"gen_len": summ_len})
|
||||||
return rouge
|
return rouge
|
||||||
|
|
||||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||||
pred_str, label_str = decode_pred(pred)
|
pred_str, label_str = decode_pred(pred)
|
||||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||||
gen_len = np.mean(lmap(non_pad_len, pred.predictions))
|
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||||
bleu.update({"gen_len": gen_len})
|
bleu.update({"gen_len": gen_len})
|
||||||
return bleu
|
return bleu
|
||||||
|
|
||||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||||
return compute_metrics_fn
|
return compute_metrics_fn
|
||||||
|
|
||||||
def freeze_embeds(model: torch.nn.Module):
|
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
|
||||||
try:
|
|
||||||
freeze_params(model.model.shared)
|
|
||||||
for d in [model.model.encoder, model.model.decoder]:
|
|
||||||
freeze_params(d.embed_positions)
|
|
||||||
freeze_params(d.embed_tokens)
|
|
||||||
except AttributeError:
|
|
||||||
freeze_params(model.shared)
|
|
||||||
for d in [model.encoder, model.decoder]:
|
|
||||||
freeze_params(d.embed_tokens)
|
|
||||||
|
|
||||||
if model_args.freeze_embeds:
|
if model_args.freeze_embeds:
|
||||||
freeze_embeds(model)
|
freeze_embeds(model)
|
||||||
if model_args.freeze_encoder:
|
if model_args.freeze_encoder:
|
||||||
@@ -376,6 +356,7 @@ def main():
|
|||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||||
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
||||||
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
@@ -396,41 +377,36 @@ def main():
|
|||||||
|
|
||||||
result = trainer.evaluate()
|
result = trainer.evaluate()
|
||||||
|
|
||||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
|
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
logger.info("***** Eval results *****")
|
logger.info("***** Eval results *****")
|
||||||
for key, value in result.items():
|
for key, value in result.items():
|
||||||
logger.info(" %s = %s", key, value)
|
logger.info(" %s = %s", key, value)
|
||||||
|
save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
|
||||||
with open(output_eval_file, "w") as f:
|
|
||||||
json.dump(result, f)
|
|
||||||
|
|
||||||
eval_results.update(result)
|
eval_results.update(result)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logging.info("*** Test ***")
|
logging.info("*** Test ***")
|
||||||
|
|
||||||
test_output = trainer.predict(test_dataset=test_dataset)
|
test_output = trainer.predict(test_dataset=test_dataset)
|
||||||
test_metrics = test_output.metrics
|
test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
|
||||||
test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
|
|
||||||
|
|
||||||
output_test_file = os.path.join(training_args.output_dir, "test_results.json")
|
|
||||||
|
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
logger.info("***** Test results *****")
|
logger.info("***** Test results *****")
|
||||||
for key, value in test_metrics.items():
|
for key, value in test_metrics.items():
|
||||||
logger.info(" %s = %s", key, value)
|
logger.info(" %s = %s", key, value)
|
||||||
|
|
||||||
with open(output_test_file, "w") as f:
|
save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
||||||
json.dump(test_metrics, f)
|
eval_results.update(test_metrics)
|
||||||
|
|
||||||
if training_args.predict_with_generate:
|
if training_args.predict_with_generate:
|
||||||
test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
|
test_preds = tokenizer.batch_decode(
|
||||||
|
test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
test_preds = lmap(str.strip, test_preds)
|
test_preds = lmap(str.strip, test_preds)
|
||||||
output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
|
||||||
with open(output_test_pred_file, "w") as f:
|
|
||||||
f.write("\n".join(test_preds))
|
|
||||||
|
|
||||||
|
if trainer.is_world_process_zero():
|
||||||
|
save_json(eval_results, "all_results.json")
|
||||||
return eval_results
|
return eval_results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Seq2SeqTrainer(Trainer):
|
class Seq2SeqTrainer(Trainer):
|
||||||
|
def __init__(self, data_args, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.data_args = data_args
|
||||||
|
self.max_gen_length = data_args.val_max_target_length
|
||||||
|
self.pad_token_id = self.model.config.pad_token_id
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
return None
|
return None
|
||||||
@@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
labels = inputs.pop("labels")
|
labels = inputs.pop("labels")
|
||||||
outputs = model(**inputs, use_cache=False)
|
outputs = model(**inputs, use_cache=False)
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
|
return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)
|
||||||
|
|
||||||
def _compute_loss(self, logits, labels, ignore_index):
|
def _compute_loss(self, logits, labels, ignore_index):
|
||||||
if self.args.label_smoothing == 0:
|
if self.args.label_smoothing == 0:
|
||||||
@@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
max_length = (
|
|
||||||
model.config.max_generate_length
|
|
||||||
if hasattr(model.config, "max_generate_length")
|
|
||||||
else model.config.max_position_embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||||
generated_tokens = model.generate(
|
generated_tokens = model.generate(
|
||||||
inputs["input_ids"],
|
inputs["input_ids"],
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
num_beams=model.config.num_beams,
|
num_beams=self.data_args.eval_beams,
|
||||||
max_length=max_length,
|
max_length=self.max_gen_length,
|
||||||
)
|
)
|
||||||
# in case the batch is shorter than max length, the output should be padded
|
# in case the batch is shorter than max length, the output should be padded
|
||||||
generated_tokens = self._pad_tensors_to_max_len(
|
generated_tokens = self._pad_tensors_to_max_len(
|
||||||
generated_tokens, max_length, model.config.pad_token_id
|
generated_tokens, self.max_gen_length, self.pad_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
labels_out = inputs.get("labels")
|
labels_out = inputs.get("labels")
|
||||||
outputs = model(**inputs)
|
# Call forward again to get loss # TODO: avoidable?
|
||||||
logits = outputs[1]
|
outputs = model(**inputs, use_cache=False)
|
||||||
loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
|
loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
|
||||||
loss = loss.mean().item()
|
loss = loss.mean().item()
|
||||||
if self.args.prediction_loss_only:
|
|
||||||
logits = None
|
|
||||||
else:
|
|
||||||
logits = generated_tokens if self.args.predict_with_generate else logits
|
|
||||||
|
|
||||||
if self.args.prediction_loss_only:
|
if self.args.prediction_loss_only:
|
||||||
return (loss, None, None)
|
return (loss, None, None)
|
||||||
|
|
||||||
|
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
||||||
|
|
||||||
labels_out = labels_out.detach()
|
labels_out = labels_out.detach()
|
||||||
labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
|
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
|
||||||
return (loss, logits.detach(), labels)
|
return (loss, logits.detach(), labels)
|
||||||
|
|
||||||
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
||||||
|
|||||||
@@ -3,36 +3,54 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers import BartForConditionalGeneration, MarianMTModel
|
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
from .finetune_trainer import main
|
from .finetune_trainer import main
|
||||||
from .test_seq2seq_examples import MBART_TINY
|
from .test_seq2seq_examples import MBART_TINY
|
||||||
from .utils import load_json
|
from .utils import load_json
|
||||||
|
|
||||||
|
|
||||||
MODEL_NAME = MBART_TINY
|
set_seed(42)
|
||||||
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
|
||||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||||
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_model_download():
|
|
||||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
|
||||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
|
||||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
|
||||||
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_finetune_trainer():
|
def test_finetune_trainer():
|
||||||
|
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
||||||
|
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||||
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
|
first_step_stats = eval_metrics[0]
|
||||||
|
assert "eval_bleu" in first_step_stats
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_finetune_trainer_slow():
|
||||||
|
# TODO(SS): This will fail on devices with more than 1 GPU.
|
||||||
|
# There is a missing call to __init__process_group somewhere
|
||||||
|
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||||
|
|
||||||
|
# Check metrics
|
||||||
|
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||||
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
|
first_step_stats = eval_metrics[0]
|
||||||
|
last_step_stats = eval_metrics[-1]
|
||||||
|
|
||||||
|
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||||
|
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||||
|
|
||||||
|
# test if do_predict saves generations and metrics
|
||||||
|
contents = os.listdir(output_dir)
|
||||||
|
contents = {os.path.basename(p) for p in contents}
|
||||||
|
assert "test_generations.txt" in contents
|
||||||
|
assert "test_results.json" in contents
|
||||||
|
|
||||||
|
|
||||||
|
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
output_dir = tempfile.mkdtemp(prefix="test_output")
|
||||||
max_len = "128"
|
|
||||||
num_train_epochs = 4
|
|
||||||
eval_steps = 2
|
|
||||||
argv = [
|
argv = [
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
MARIAN_MODEL,
|
model_name,
|
||||||
"--data_dir",
|
"--data_dir",
|
||||||
data_dir,
|
data_dir,
|
||||||
"--output_dir",
|
"--output_dir",
|
||||||
@@ -72,25 +90,17 @@ def test_finetune_trainer():
|
|||||||
"--sortish_sampler",
|
"--sortish_sampler",
|
||||||
"--label_smoothing",
|
"--label_smoothing",
|
||||||
"0.1",
|
"0.1",
|
||||||
|
# "--eval_beams",
|
||||||
|
# "2",
|
||||||
"--task",
|
"--task",
|
||||||
"translation",
|
"translation",
|
||||||
|
"--tgt_lang",
|
||||||
|
"ro_RO",
|
||||||
|
"--src_lang",
|
||||||
|
"en_XX",
|
||||||
]
|
]
|
||||||
|
|
||||||
testargs = ["finetune_trainer.py"] + argv
|
testargs = ["finetune_trainer.py"] + argv
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
main()
|
main()
|
||||||
|
|
||||||
# Check metrics
|
return output_dir
|
||||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
|
||||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
|
||||||
first_step_stats = eval_metrics[0]
|
|
||||||
last_step_stats = eval_metrics[-1]
|
|
||||||
|
|
||||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
|
||||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
|
||||||
|
|
||||||
# test if do_predict saves generations and metrics
|
|
||||||
contents = os.listdir(output_dir)
|
|
||||||
contents = {os.path.basename(p) for p in contents}
|
|
||||||
assert "test_generations.txt" in contents
|
|
||||||
assert "test_results.json" in contents
|
|
||||||
|
|||||||
@@ -441,6 +441,25 @@ def freeze_params(model: nn.Module):
|
|||||||
par.requires_grad = False
|
par.requires_grad = False
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_embeds(model):
|
||||||
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
|
model_type = model.config.model_type
|
||||||
|
|
||||||
|
if model_type == "t5":
|
||||||
|
freeze_params(model.shared)
|
||||||
|
for d in [model.encoder, model.decoder]:
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
elif model_type == "fsmt":
|
||||||
|
for d in [model.model.encoder, model.model.decoder]:
|
||||||
|
freeze_params(d.embed_positions)
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
else:
|
||||||
|
freeze_params(model.model.shared)
|
||||||
|
for d in [model.model.encoder, model.model.decoder]:
|
||||||
|
freeze_params(d.embed_positions)
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
|
||||||
|
|
||||||
def grad_status(model: nn.Module) -> Iterable:
|
def grad_status(model: nn.Module) -> Iterable:
|
||||||
return (par.requires_grad for par in model.parameters())
|
return (par.requires_grad for par in model.parameters())
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user