Upgrade examples to pl=0.8.1(#5146)
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Any, Dict
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AdamW,
|
AdamW,
|
||||||
@@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
model=None,
|
model=None,
|
||||||
**config_kwargs
|
**config_kwargs
|
||||||
):
|
):
|
||||||
"Initialize a model."
|
"""Initialize a model, tokenizer and config."""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hparams = hparams
|
self.hparams = hparams # TODO: move to self.save_hyperparameters()
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
self.tfmr_ckpts = {}
|
self.tfmr_ckpts = {}
|
||||||
self.output_dir = Path(self.hparams.output_dir)
|
self.output_dir = Path(self.hparams.output_dir)
|
||||||
@@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||||
if model is None:
|
|
||||||
self.model_type = MODEL_MODES[mode]
|
self.model_type = MODEL_MODES[mode]
|
||||||
|
if model is None:
|
||||||
self.model = self.model_type.from_pretrained(
|
self.model = self.model_type.from_pretrained(
|
||||||
self.hparams.model_name_or_path,
|
self.hparams.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||||
@@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.model_type = None
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def load_hf_checkpoint(self, *args, **kwargs):
|
def load_hf_checkpoint(self, *args, **kwargs):
|
||||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
def is_logger(self):
|
|
||||||
return self.trainer.proc_rank <= 0
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
@@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
self.opt = optimizer
|
self.opt = optimizer
|
||||||
return [optimizer]
|
return [optimizer]
|
||||||
|
|
||||||
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
|
|
||||||
if self.trainer.use_tpu:
|
|
||||||
xm.optimizer_step(optimizer)
|
|
||||||
else:
|
|
||||||
optimizer.step()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
|
|
||||||
def get_tqdm_dict(self):
|
|
||||||
avg_loss = getattr(self.trainer, "avg_loss", 0.0)
|
|
||||||
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
|
|
||||||
return tqdm_dict
|
|
||||||
|
|
||||||
def test_step(self, batch, batch_nb):
|
def test_step(self, batch, batch_nb):
|
||||||
return self.validation_step(batch, batch_nb)
|
return self.validation_step(batch, batch_nb)
|
||||||
|
|
||||||
def test_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
return self.validation_end(outputs)
|
return self.validation_end(outputs)
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
@@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
|
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
||||||
)
|
)
|
||||||
@@ -217,21 +200,19 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
|
|
||||||
|
|
||||||
class LoggingCallback(pl.Callback):
|
class LoggingCallback(pl.Callback):
|
||||||
|
@rank_zero_only
|
||||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
logger.info("***** Validation results *****")
|
rank_zero_info("***** Validation results *****")
|
||||||
if pl_module.is_logger():
|
|
||||||
metrics = trainer.callback_metrics
|
metrics = trainer.callback_metrics
|
||||||
# Log results
|
# Log results
|
||||||
for key in sorted(metrics):
|
for key in sorted(metrics):
|
||||||
if key not in ["log", "progress_bar"]:
|
if key not in ["log", "progress_bar"]:
|
||||||
logger.info("{} = {}\n".format(key, str(metrics[key])))
|
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
logger.info("***** Test results *****")
|
logger.info("***** Test results *****")
|
||||||
|
|
||||||
if pl_module.is_logger():
|
|
||||||
metrics = trainer.callback_metrics
|
metrics = trainer.callback_metrics
|
||||||
|
|
||||||
# Log and save results to file
|
# Log and save results to file
|
||||||
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
||||||
with open(output_test_results_file, "w") as writer:
|
with open(output_test_results_file, "w") as writer:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ psutil
|
|||||||
sacrebleu
|
sacrebleu
|
||||||
rouge-score
|
rouge-score
|
||||||
tensorflow_datasets
|
tensorflow_datasets
|
||||||
pytorch-lightning==0.7.6
|
pytorch-lightning==0.8.1
|
||||||
matplotlib
|
matplotlib
|
||||||
git-python==1.0.3
|
git-python==1.0.3
|
||||||
faiss
|
faiss
|
||||||
|
|||||||
@@ -19,12 +19,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Seq2SeqLoggingCallback(pl.Callback):
|
class Seq2SeqLoggingCallback(pl.Callback):
|
||||||
|
@rank_zero_only
|
||||||
def _write_logs(
|
def _write_logs(
|
||||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
||||||
if not pl_module.is_logger():
|
|
||||||
return
|
|
||||||
metrics = trainer.callback_metrics
|
metrics = trainer.callback_metrics
|
||||||
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
||||||
# Log results
|
# Log results
|
||||||
|
|||||||
@@ -271,6 +271,7 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
|
|
||||||
class T5SummarizationDistiller(SummarizationDistiller):
|
class T5SummarizationDistiller(SummarizationDistiller):
|
||||||
def pre_init(self, hparams):
|
def pre_init(self, hparams):
|
||||||
|
raise NotImplementedError("T5 Distillation does not work yet")
|
||||||
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
||||||
n_layer = hparams.student_decoder_layers
|
n_layer = hparams.student_decoder_layers
|
||||||
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
|
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
if self.hparams.freeze_encoder:
|
if self.hparams.freeze_encoder:
|
||||||
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
||||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||||
self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu
|
self.num_workers = hparams.num_workers
|
||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
@@ -126,7 +126,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
def validation_step(self, batch, batch_idx) -> Dict:
|
def validation_step(self, batch, batch_idx) -> Dict:
|
||||||
return self._generative_step(batch)
|
return self._generative_step(batch)
|
||||||
|
|
||||||
def validation_end(self, outputs, prefix="val") -> Dict:
|
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||||
self.step_count += 1
|
self.step_count += 1
|
||||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||||
loss = losses["loss"]
|
loss = losses["loss"]
|
||||||
@@ -144,14 +144,12 @@ class SummarizationModule(BaseTransformer):
|
|||||||
self.metrics[prefix].append(metrics)
|
self.metrics[prefix].append(metrics)
|
||||||
pickle_save(self.metrics, self.metrics_save_path)
|
pickle_save(self.metrics, self.metrics_save_path)
|
||||||
|
|
||||||
def _generative_step(self, batch):
|
def _generative_step(self, batch: dict) -> dict:
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||||
# TODO(SS): task specific params
|
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
||||||
gen_time = time.time() - t0
|
gen_time = time.time() - t0 / source_ids.shape[0]
|
||||||
preds = self.ids_to_clean_text(generated_ids)
|
preds = self.ids_to_clean_text(generated_ids)
|
||||||
target = self.ids_to_clean_text(y)
|
target = self.ids_to_clean_text(y)
|
||||||
loss_tensors = self._step(batch)
|
loss_tensors = self._step(batch)
|
||||||
@@ -164,24 +162,8 @@ class SummarizationModule(BaseTransformer):
|
|||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
return self._generative_step(batch)
|
return self._generative_step(batch)
|
||||||
|
|
||||||
def test_end(self, outputs):
|
|
||||||
return self.validation_end(outputs, prefix="test")
|
|
||||||
|
|
||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
|
return self.validation_epoch_end(outputs, prefix="test")
|
||||||
output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
|
|
||||||
# write predictions and targets for later rouge evaluation.
|
|
||||||
with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
|
|
||||||
for output_batch in outputs:
|
|
||||||
p_writer.writelines(s + "\n" for s in output_batch["preds"])
|
|
||||||
t_writer.writelines(s + "\n" for s in output_batch["target"])
|
|
||||||
p_writer.close()
|
|
||||||
t_writer.close()
|
|
||||||
|
|
||||||
return self.test_end(outputs)
|
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
|
||||||
self.validation_end(outputs, "val")
|
|
||||||
|
|
||||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||||
n_obs = self.n_obs[type_path]
|
n_obs = self.n_obs[type_path]
|
||||||
@@ -310,6 +292,7 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
logger=logger,
|
logger=logger,
|
||||||
# TODO: early stopping callback seems messed up
|
# TODO: early stopping callback seems messed up
|
||||||
)
|
)
|
||||||
|
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||||
if not args.do_predict:
|
if not args.do_predict:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,5 @@ python distillation.py \
|
|||||||
--learning_rate=3e-4 \
|
--learning_rate=3e-4 \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
--fp16 \
|
|
||||||
--val_check_interval 0.1 \
|
--val_check_interval 0.1 \
|
||||||
$@
|
$@
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ def generate_summaries(
|
|||||||
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
|
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
fout = Path(out_file).open("w", encoding="utf-8")
|
fout = Path(out_file).open("w", encoding="utf-8")
|
||||||
|
model_name = str(model_name)
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
|
||||||
if fp16:
|
if fp16:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ logger = logging.getLogger()
|
|||||||
FP16_EVER = False
|
FP16_EVER = False
|
||||||
CHEAP_ARGS = {
|
CHEAP_ARGS = {
|
||||||
"logger": "default",
|
"logger": "default",
|
||||||
|
"num_workers": 2,
|
||||||
"alpha_hid": 0,
|
"alpha_hid": 0,
|
||||||
"freeze_embeds": True,
|
"freeze_embeds": True,
|
||||||
"enc_only": False,
|
"enc_only": False,
|
||||||
@@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list):
|
|||||||
f.write("\n".join(articles))
|
f.write("\n".join(articles))
|
||||||
|
|
||||||
|
|
||||||
BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
|
MSG = "T5 is broken at the moment"
|
||||||
|
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||||
|
|
||||||
|
|
||||||
def make_test_data_dir():
|
def make_test_data_dir():
|
||||||
@@ -92,7 +94,6 @@ def make_test_data_dir():
|
|||||||
return tmp_dir
|
return tmp_dir
|
||||||
|
|
||||||
|
|
||||||
@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
|
|
||||||
class TestSummarizationDistiller(unittest.TestCase):
|
class TestSummarizationDistiller(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
freeze_encoder=True,
|
freeze_encoder=True,
|
||||||
gpus=2,
|
gpus=2,
|
||||||
sortish_sampler=False,
|
sortish_sampler=False,
|
||||||
)
|
|
||||||
self._bart_distiller_cli(updates)
|
|
||||||
|
|
||||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
|
||||||
def test_bdc_fp16(self):
|
|
||||||
updates = dict(
|
|
||||||
student_encoder_layers=2,
|
|
||||||
student_decoder_layers=1,
|
|
||||||
alpha_hid=3.0,
|
|
||||||
freeze_encoder=True,
|
|
||||||
gpus=1,
|
|
||||||
fp16=FP16_EVER,
|
|
||||||
fp16_opt_level="O1",
|
fp16_opt_level="O1",
|
||||||
|
fp16=FP16_EVER,
|
||||||
)
|
)
|
||||||
self._bart_distiller_cli(updates)
|
self._bart_distiller_cli(updates)
|
||||||
|
|
||||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
def test_bdc_t5_train(self):
|
||||||
def test_bdc_t5_eval_fp16(self):
|
|
||||||
updates = dict(
|
updates = dict(
|
||||||
fp16=FP16_EVER,
|
fp16=FP16_EVER,
|
||||||
gpus=1,
|
gpus=1 if torch.cuda.is_available() else 0,
|
||||||
model_type="t5",
|
model_type="t5",
|
||||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
model_name_or_path=T5_TINY,
|
||||||
do_train=False,
|
|
||||||
do_predict=True,
|
|
||||||
tokenizer_name=None,
|
|
||||||
no_teacher=True,
|
|
||||||
)
|
|
||||||
self._bart_distiller_cli(updates, check_contents=False)
|
|
||||||
|
|
||||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
|
||||||
def test_bdc_t5_train_fp16(self):
|
|
||||||
updates = dict(
|
|
||||||
fp16=FP16_EVER,
|
|
||||||
gpus=1,
|
|
||||||
model_type="t5",
|
|
||||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
|
||||||
do_train=True,
|
do_train=True,
|
||||||
do_predict=True,
|
do_predict=True,
|
||||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
tokenizer_name=T5_TINY,
|
||||||
no_teacher=True,
|
no_teacher=True,
|
||||||
|
alpha_hid=2.0,
|
||||||
)
|
)
|
||||||
self._bart_distiller_cli(updates)
|
self._bart_distiller_cli(updates)
|
||||||
|
|
||||||
@@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
self._bart_distiller_cli(updates)
|
self._bart_distiller_cli(updates)
|
||||||
|
|
||||||
def test_bdc_checkpointing(self):
|
def test_bdc_checkpointing(self):
|
||||||
|
|
||||||
updates = dict(
|
updates = dict(
|
||||||
student_encoder_layers=2,
|
student_encoder_layers=2,
|
||||||
student_decoder_layers=1,
|
student_decoder_layers=1,
|
||||||
@@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
|
|
||||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||||
|
|
||||||
def test_bdc_t5(self):
|
|
||||||
updates = dict(
|
|
||||||
student_encoder_layers=1,
|
|
||||||
student_decoder_layers=1,
|
|
||||||
alpha_hid=2.0,
|
|
||||||
teacher="patrickvonplaten/t5-tiny-random",
|
|
||||||
model_type="t5",
|
|
||||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
|
||||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
|
||||||
)
|
|
||||||
self._bart_distiller_cli(updates)
|
|
||||||
|
|
||||||
def test_bdc_t5_eval(self):
|
|
||||||
updates = dict(
|
|
||||||
model_type="t5",
|
|
||||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
|
||||||
do_train=False,
|
|
||||||
do_predict=True,
|
|
||||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
|
||||||
no_teacher=True,
|
|
||||||
)
|
|
||||||
self._bart_distiller_cli(updates, check_contents=False)
|
|
||||||
|
|
||||||
def _bart_distiller_cli(self, updates, check_contents=True):
|
def _bart_distiller_cli(self, updates, check_contents=True):
|
||||||
default_updates = dict(
|
default_updates = dict(
|
||||||
model_type="bart",
|
|
||||||
train_batch_size=1,
|
train_batch_size=1,
|
||||||
eval_batch_size=2,
|
eval_batch_size=2,
|
||||||
num_train_epochs=2,
|
num_train_epochs=2,
|
||||||
@@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
self.assertIn(ckpt_name, contents)
|
self.assertIn(ckpt_name, contents)
|
||||||
self.assertIn("metrics.pkl", contents)
|
self.assertIn("metrics.pkl", contents)
|
||||||
self.assertIn("test_generations.txt", contents)
|
self.assertIn("test_generations.txt", contents)
|
||||||
self.assertIn("val_generations_1.txt", contents)
|
self.assertIn("val_generations_00001.txt", contents)
|
||||||
self.assertIn("val_1_results.txt", contents)
|
self.assertIn("val_results_00001.txt", contents)
|
||||||
self.assertIn("test_results.txt", contents)
|
self.assertIn("test_results.txt", contents)
|
||||||
# self.assertEqual(len(contents), 15)
|
|
||||||
|
|
||||||
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
|
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
|
||||||
import pandas as pd
|
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||||
|
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||||
val_df = pd.DataFrame(metrics["val"])
|
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
|
||||||
train_df = pd.DataFrame(metrics["train"])
|
|
||||||
test_df = pd.DataFrame(metrics["test"])
|
|
||||||
desired_n_evals = args_d["num_train_epochs"] * 2 + 1
|
|
||||||
self.assertEqual(val_df.shape[0], desired_n_evals) #
|
|
||||||
self.assertEqual(test_df.shape[1], val_df.shape[1])
|
|
||||||
self.assertEqual(train_df.shape[0], 0)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase):
|
|||||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||||
args_d.update(
|
args_d.update(
|
||||||
data_dir=tmp_dir,
|
data_dir=tmp_dir,
|
||||||
model_type="t5",
|
model_name_or_path=T5_TINY,
|
||||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
tokenizer_name=None, # T5_TINY,
|
||||||
tokenizer_name=None, # "patrickvonplaten/t5-tiny-random",
|
|
||||||
train_batch_size=2,
|
train_batch_size=2,
|
||||||
eval_batch_size=2,
|
eval_batch_size=2,
|
||||||
gpus=0,
|
gpus=0,
|
||||||
|
|||||||
@@ -45,8 +45,10 @@ def encode_file(
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_to_max_length=pad_to_max_length,
|
pad_to_max_length=pad_to_max_length,
|
||||||
add_prefix_space=True,
|
add_prefix_space=True,
|
||||||
|
truncation=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
)
|
)
|
||||||
|
assert tokenized.input_ids.shape[1] == max_length
|
||||||
examples.append(tokenized)
|
examples.append(tokenized)
|
||||||
torch.save(lmap(dict, examples), cache_path.open("wb"))
|
torch.save(lmap(dict, examples), cache_path.open("wb"))
|
||||||
return examples
|
return examples
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer):
|
|||||||
|
|
||||||
return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}
|
return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}
|
||||||
|
|
||||||
def _eval_end(self, outputs):
|
def _eval_end(self, outputs) -> tuple:
|
||||||
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
|
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
|
||||||
preds = np.concatenate([x["pred"] for x in outputs], axis=0)
|
preds = np.concatenate([x["pred"] for x in outputs], axis=0)
|
||||||
|
|
||||||
@@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer):
|
|||||||
logs = ret["log"]
|
logs = ret["log"]
|
||||||
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
|
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
|
||||||
|
|
||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs) -> dict:
|
||||||
# updating to test_epoch_end instead of deprecated test_end
|
|
||||||
ret, predictions, targets = self._eval_end(outputs)
|
ret, predictions, targets = self._eval_end(outputs)
|
||||||
|
|
||||||
# Converting to the dic required by pl
|
|
||||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
|
|
||||||
# pytorch_lightning/trainer/logging.py#L139
|
|
||||||
logs = ret["log"]
|
logs = ret["log"]
|
||||||
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
|
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
|
||||||
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
|
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
# Add NER specific options
|
|
||||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_seq_length",
|
"--max_seq_length",
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ class AutoTokenizer:
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if "bert-base-japanese" in pretrained_model_name_or_path:
|
if "bert-base-japanese" in str(pretrained_model_name_or_path):
|
||||||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
|
||||||
use_fast = kwargs.pop("use_fast", False)
|
use_fast = kwargs.pop("use_fast", False)
|
||||||
|
|||||||
Reference in New Issue
Block a user