examples/seq2seq supports translation (#5202)
This commit is contained in:
169
examples/seq2seq/README.md
Normal file
169
examples/seq2seq/README.md
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
|
||||||
|
Summarization support is more mature than translation support.
|
||||||
|
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
|
||||||
|
For `bertabs` instructions, see `bertabs/README.md`.
|
||||||
|
|
||||||
|
### Data
|
||||||
|
|
||||||
|
CNN/DailyMail data
|
||||||
|
```bash
|
||||||
|
cd examples/seq2seq
|
||||||
|
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
|
||||||
|
tar -xzvf cnn_dm.tgz
|
||||||
|
|
||||||
|
export CNN_DIR=${PWD}/cnn_dm
|
||||||
|
```
|
||||||
|
|
||||||
|
this should make a directory called cnn_dm/ with files like `test.source`.
|
||||||
|
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
||||||
|
|
||||||
|
XSUM Data:
|
||||||
|
```bash
|
||||||
|
cd examples/seq2seq
|
||||||
|
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
||||||
|
tar -xzvf xsum.tar.gz
|
||||||
|
export XSUM_DIR=${PWD}/xsum
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
WMT16 English-Romanian Translation Data:
|
||||||
|
```bash
|
||||||
|
cd examples/seq2seq
|
||||||
|
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
|
||||||
|
tar -xzvf wmt_en_ro.tar.gz
|
||||||
|
export ENRO_DIR=${PWD}/wmt_en_ro
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
|
||||||
|
The `.source` files are the input, the `.target` files are the desired output.
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
To create summaries for each article in dataset, run:
|
||||||
|
```bash
|
||||||
|
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
|
||||||
|
```
|
||||||
|
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
||||||
|
|
||||||
|
|
||||||
|
### Summarization Finetuning
|
||||||
|
Run/modify `finetune.sh`
|
||||||
|
|
||||||
|
The following command should work on a 16GB GPU:
|
||||||
|
```bash
|
||||||
|
./finetune.sh \
|
||||||
|
--data_dir $XSUM_DIR \
|
||||||
|
--train_batch_size=1 \
|
||||||
|
--eval_batch_size=1 \
|
||||||
|
--output_dir=xsum_results \
|
||||||
|
--num_train_epochs 1 \
|
||||||
|
--model_name_or_path facebook/bart-large
|
||||||
|
```
|
||||||
|
|
||||||
|
*Note*: The following tips mostly apply to summarization finetuning.
|
||||||
|
|
||||||
|
Tips:
|
||||||
|
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
|
||||||
|
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
|
||||||
|
- `fp16_opt_level=O1` (the default works best).
|
||||||
|
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
||||||
|
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
||||||
|
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
|
||||||
|
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
|
||||||
|
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
|
||||||
|
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||||
|
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
|
||||||
|
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
|
||||||
|
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
|
||||||
|
- This warning can be safely ignored:
|
||||||
|
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
|
||||||
|
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
|
||||||
|
|
||||||
|
#### Finetuning Outputs
|
||||||
|
As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine).
|
||||||
|
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
output_dir
|
||||||
|
├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
|
||||||
|
│ ├── config.json
|
||||||
|
│ ├── merges.txt
|
||||||
|
│ ├── pytorch_model.bin
|
||||||
|
│ ├── special_tokens_map.json
|
||||||
|
│ ├── tokenizer_config.json
|
||||||
|
│ └── vocab.json
|
||||||
|
├── git_log.json # repo, branch, and commit hash
|
||||||
|
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score.
|
||||||
|
├── metrics.json # new validation metrics will continually be appended to this
|
||||||
|
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
|
||||||
|
│ ├── config.json
|
||||||
|
│ └── pytorch_model.bin
|
||||||
|
├── test_generations.txt
|
||||||
|
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
|
||||||
|
├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test']
|
||||||
|
├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly.
|
||||||
|
```
|
||||||
|
After training, you can recover the best checkpoint by running
|
||||||
|
```python
|
||||||
|
from transformers import AutoModelForSeq2SeqLM
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### XSUM Shared Task
|
||||||
|
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
|
||||||
|
|
||||||
|
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
||||||
|
```bash
|
||||||
|
./finetune.sh \
|
||||||
|
--data_dir $XSUM_DIR \
|
||||||
|
--output_dir xsum_frozen_embs \
|
||||||
|
--model_name_or_path facebook/bart-large \
|
||||||
|
--logger wandb_shared \
|
||||||
|
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||||
|
--num_train_epochs 6 \
|
||||||
|
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
|
||||||
|
```
|
||||||
|
|
||||||
|
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
||||||
|
|
||||||
|
|
||||||
|
### Distilbart
|
||||||
|
|
||||||
|
#### No Teacher Distillation
|
||||||
|
To run the simpler distilbart-cnn style distillation all you need is data, a GPU, and a properly initialized student.
|
||||||
|
You don't even need `distillation.py`.
|
||||||
|
|
||||||
|
Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
|
||||||
|
They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate`. (You can read about that in `initialization_utils.py`)
|
||||||
|
The command that produced `sshleifer/distilbart-cnn-12-6` is
|
||||||
|
```bash
|
||||||
|
./train_distilbart_cnn.sh
|
||||||
|
```
|
||||||
|
runtime: 6H on NVIDIA RTX 24GB GPU
|
||||||
|
|
||||||
|
*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
|
||||||
|
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
|
||||||
|
because you will have the same hyperparameters logged in every run.
|
||||||
|
|
||||||
|
#### With a teacher
|
||||||
|
*Note* only BART variants are supported
|
||||||
|
|
||||||
|
In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
|
||||||
|
This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
|
||||||
|
|
||||||
|
The command that produced `sshleifer/distilbart-xsum-12-6` is:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./train_distilbart_xsum.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
runtime: 13H on V-100 16GB GPU.
|
||||||
|
|
||||||
|
### Contributing
|
||||||
|
- follow the standard contributing guidelines and code of conduct.
|
||||||
|
- add tests to `test_seq2seq_examples.py`
|
||||||
|
- To run only the seq2seq tests, you must be in the root of the repository and run:
|
||||||
|
```bash
|
||||||
|
pytest examples/seq2seq/
|
||||||
|
```
|
||||||
@@ -12,7 +12,7 @@ The model is loaded with the pre-trained weights for the abstractive summarizati
|
|||||||
git clone https://github.com/huggingface/transformers && cd transformers
|
git clone https://github.com/huggingface/transformers && cd transformers
|
||||||
pip install .
|
pip install .
|
||||||
pip install nltk py-rouge
|
pip install nltk py-rouge
|
||||||
cd examples/summarization
|
cd examples/seq2seq/bertabs
|
||||||
```
|
```
|
||||||
|
|
||||||
## Reproduce the authors' ROUGE score
|
## Reproduce the authors' ROUGE score
|
||||||
@@ -32,9 +32,12 @@ class Seq2SeqLoggingCallback(pl.Callback):
|
|||||||
results_file = od / "test_results.txt"
|
results_file = od / "test_results.txt"
|
||||||
generations_file = od / "test_generations.txt"
|
generations_file = od / "test_generations.txt"
|
||||||
else:
|
else:
|
||||||
results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt"
|
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
||||||
generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt"
|
# If people want this it will be easy enough to add back.
|
||||||
|
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
||||||
|
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
||||||
|
results_file.parent.mkdir(exist_ok=True)
|
||||||
|
generations_file.parent.mkdir(exist_ok=True)
|
||||||
with open(results_file, "a+") as writer:
|
with open(results_file, "a+") as writer:
|
||||||
for key in sorted(metrics):
|
for key in sorted(metrics):
|
||||||
if key in ["log", "progress_bar", "preds"]:
|
if key in ["log", "progress_bar", "preds"]:
|
||||||
@@ -63,20 +66,25 @@ class Seq2SeqLoggingCallback(pl.Callback):
|
|||||||
# mp stands for million parameters
|
# mp stands for million parameters
|
||||||
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||||
|
|
||||||
@rank_zero_only
|
|
||||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
|
||||||
return self._write_logs(trainer, pl_module, "val")
|
|
||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
return self._write_logs(trainer, pl_module, "test")
|
return self._write_logs(trainer, pl_module, "test")
|
||||||
|
|
||||||
|
|
||||||
def get_rouge2_checkpoint_callback(output_dir):
|
def get_checkpoint_callback(output_dir, metric):
|
||||||
"""Saves the best model by validation ROUGE2 score."""
|
"""Saves the best model by validation ROUGE2 score."""
|
||||||
|
if metric == "rouge2":
|
||||||
|
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||||
|
elif metric == "bleu":
|
||||||
|
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
|
||||||
|
)
|
||||||
|
|
||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"),
|
filepath=os.path.join(output_dir, exp),
|
||||||
monitor="val_rouge",
|
monitor=f"val_{metric}",
|
||||||
mode="max",
|
mode="max",
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
@@ -39,13 +39,12 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SummarizationDistiller(SummarizationModule):
|
class BartSummarizationDistiller(SummarizationModule):
|
||||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
assert Path(hparams.data_dir).exists()
|
assert Path(hparams.data_dir).exists()
|
||||||
|
student, student_cfg, teacher = self.pre_init(hparams)
|
||||||
d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams)
|
|
||||||
|
|
||||||
super().__init__(hparams, model=student, config=student_cfg)
|
super().__init__(hparams, model=student, config=student_cfg)
|
||||||
self.teacher = teacher
|
self.teacher = teacher
|
||||||
@@ -73,12 +72,15 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
del self.teacher.model.encoder
|
del self.teacher.model.encoder
|
||||||
|
|
||||||
def pre_init(self, hparams):
|
def pre_init(self, hparams):
|
||||||
# Dump empty student model at a path, then call from_pretrained on it
|
self.output_dir = Path(hparams.output_dir)
|
||||||
|
self.output_dir.mkdir(exist_ok=True)
|
||||||
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
|
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
|
||||||
student_updates = {
|
student_updates = {
|
||||||
"decoder_layers": hparams.student_decoder_layers,
|
"decoder_layers": hparams.student_decoder_layers,
|
||||||
"encoder_layers": hparams.student_encoder_layers,
|
"encoder_layers": hparams.student_encoder_layers,
|
||||||
}
|
}
|
||||||
|
if hparams.length_penalty != -1:
|
||||||
|
student_updates["length_penalty"] = hparams.length_penalty
|
||||||
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||||
hparams.d_layer_to_copy = d_layers_to_copy
|
hparams.d_layer_to_copy = d_layers_to_copy
|
||||||
@@ -89,9 +91,13 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
student_cfg = BartConfig(**kw)
|
student_cfg = BartConfig(**kw)
|
||||||
student = BartForConditionalGeneration(student_cfg)
|
student = BartForConditionalGeneration(student_cfg)
|
||||||
student, _ = init_student(student, teacher)
|
student, _ = init_student(student, teacher)
|
||||||
|
save_dir = self.output_dir.joinpath("student")
|
||||||
|
save_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
student.save_pretrained(save_dir)
|
||||||
return d_layers_to_copy, student, student_cfg, teacher
|
hparams.model_name_or_path = str(save_dir)
|
||||||
|
return student, student_cfg, teacher
|
||||||
|
|
||||||
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||||
if teacher.config.model_type == "t5":
|
if teacher.config.model_type == "t5":
|
||||||
@@ -154,7 +160,6 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
@@ -180,18 +185,11 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
||||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||||
parser.add_argument(
|
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||||
"--student_decoder_layers", default=12, type=int, required=False,
|
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||||
)
|
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||||
parser.add_argument(
|
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||||
"--student_encoder_layers", default=12, type=int, required=False,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no_teacher", action="store_true", default=False,
|
|
||||||
)
|
|
||||||
parser.add_argument( # TODO: remove
|
|
||||||
"--enc_only", action="store_true", default=False,
|
|
||||||
)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def _step(self, batch):
|
def _step(self, batch):
|
||||||
@@ -269,12 +267,14 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
return sum(hidden_losses)
|
return sum(hidden_losses)
|
||||||
|
|
||||||
|
|
||||||
class T5SummarizationDistiller(SummarizationDistiller):
|
class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||||
def pre_init(self, hparams):
|
def pre_init(self, hparams):
|
||||||
raise NotImplementedError("T5 Distillation does not work yet")
|
raise NotImplementedError("T5 Distillation does not work yet")
|
||||||
|
self.output_dir = Path(hparams.output_dir)
|
||||||
|
self.output_dir.mkdir(exist_ok=True)
|
||||||
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
||||||
n_layer = hparams.student_decoder_layers
|
n_layer = hparams.student_decoder_layers
|
||||||
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
|
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6.
|
||||||
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
|
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
|
||||||
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
|
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
|
||||||
student_updates = {"num_layers": n_layer}
|
student_updates = {"num_layers": n_layer}
|
||||||
@@ -291,8 +291,13 @@ class T5SummarizationDistiller(SummarizationDistiller):
|
|||||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||||
task_specific_params = student.config.task_specific_params
|
task_specific_params = student.config.task_specific_params
|
||||||
if task_specific_params is not None:
|
if task_specific_params is not None:
|
||||||
student.config.update(task_specific_params.get("summarization", {}))
|
student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode
|
||||||
return d_layers_to_copy, student, student_cfg, teacher
|
save_dir = self.output_dir.joinpath("student")
|
||||||
|
save_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
student.save_pretrained(save_dir)
|
||||||
|
hparams.model_name_or_path = str(save_dir)
|
||||||
|
return student, student_cfg, teacher
|
||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
freeze_params(self.model.shared)
|
freeze_params(self.model.shared)
|
||||||
@@ -386,7 +391,7 @@ def create_module(args):
|
|||||||
elif args.enc_only:
|
elif args.enc_only:
|
||||||
raise ValueError("Deleted that")
|
raise ValueError("Deleted that")
|
||||||
else:
|
else:
|
||||||
module_cls = SummarizationDistiller
|
module_cls = BartSummarizationDistiller
|
||||||
args.setup_cls: str = module_cls.__name__
|
args.setup_cls: str = module_cls.__name__
|
||||||
model = module_cls(args)
|
model = module_cls(args)
|
||||||
return model
|
return model
|
||||||
@@ -418,18 +423,18 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
|||||||
def get_layers_to_copy(n_to_get, tot):
|
def get_layers_to_copy(n_to_get, tot):
|
||||||
all_layers = list(range(tot))
|
all_layers = list(range(tot))
|
||||||
if tot == 12: # Alternating for special cases
|
if tot == 12: # Alternating for special cases
|
||||||
layers_to_copy = { # maps # layers in student -> which teacher layers to copy
|
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
|
||||||
6: [0, 2, 4, 7, 9, 11],
|
1: [0],
|
||||||
1: [11],
|
2: [0, 6],
|
||||||
3: [0, 6, 11],
|
3: [0, 6, 11],
|
||||||
2: [0, 11],
|
|
||||||
4: [0, 4, 8, 11],
|
4: [0, 4, 8, 11],
|
||||||
|
6: [0, 2, 4, 7, 9, 11],
|
||||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||||
12: all_layers,
|
12: all_layers,
|
||||||
}
|
}
|
||||||
return layers_to_copy[n_to_get]
|
return layers_to_copy[n_to_get]
|
||||||
else:
|
else:
|
||||||
return all_layers[:n_to_get]
|
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
|
||||||
|
|
||||||
|
|
||||||
def distill_main(args):
|
def distill_main(args):
|
||||||
@@ -443,7 +448,7 @@ def distill_main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
distill_main(args)
|
distill_main(args)
|
||||||
@@ -3,6 +3,7 @@ import glob
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
@@ -23,12 +24,14 @@ try:
|
|||||||
flatten_list,
|
flatten_list,
|
||||||
pickle_save,
|
pickle_save,
|
||||||
save_git_info,
|
save_git_info,
|
||||||
|
save_json,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
get_git_info,
|
get_git_info,
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
|
calculate_bleu_score,
|
||||||
)
|
)
|
||||||
from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import (
|
from utils import (
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
@@ -37,12 +40,14 @@ except ImportError:
|
|||||||
flatten_list,
|
flatten_list,
|
||||||
pickle_save,
|
pickle_save,
|
||||||
save_git_info,
|
save_git_info,
|
||||||
|
save_json,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
get_git_info,
|
get_git_info,
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
|
calculate_bleu_score,
|
||||||
)
|
)
|
||||||
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -50,15 +55,18 @@ logger = logging.getLogger(__name__)
|
|||||||
class SummarizationModule(BaseTransformer):
|
class SummarizationModule(BaseTransformer):
|
||||||
mode = "summarization"
|
mode = "summarization"
|
||||||
loss_names = ["loss"]
|
loss_names = ["loss"]
|
||||||
|
metric_names = ROUGE_KEYS
|
||||||
|
val_metric = "rouge2"
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||||
use_task_specific_params(self.model, "summarization")
|
use_task_specific_params(self.model, "summarization")
|
||||||
save_git_info(self.hparams.output_dir)
|
save_git_info(self.hparams.output_dir)
|
||||||
self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
|
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||||
|
pickle_save(self.hparams, self.hparams_save_path)
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
self.metrics = {"train": [], "val": [], "test": []}
|
self.metrics = defaultdict(list)
|
||||||
|
|
||||||
self.dataset_kwargs: dict = dict(
|
self.dataset_kwargs: dict = dict(
|
||||||
data_dir=self.hparams.data_dir,
|
data_dir=self.hparams.data_dir,
|
||||||
@@ -89,12 +97,12 @@ class SummarizationModule(BaseTransformer):
|
|||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
if self.model.config.model_type == "bart":
|
try:
|
||||||
freeze_params(self.model.model.shared)
|
freeze_params(self.model.model.shared)
|
||||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||||
freeze_params(d.embed_positions)
|
freeze_params(d.embed_positions)
|
||||||
freeze_params(d.embed_tokens)
|
freeze_params(d.embed_tokens)
|
||||||
else:
|
except AttributeError:
|
||||||
freeze_params(self.model.shared)
|
freeze_params(self.model.shared)
|
||||||
for d in [self.model.encoder, self.model.decoder]:
|
for d in [self.model.encoder, self.model.decoder]:
|
||||||
freeze_params(d.embed_tokens)
|
freeze_params(d.embed_tokens)
|
||||||
@@ -130,19 +138,22 @@ class SummarizationModule(BaseTransformer):
|
|||||||
self.step_count += 1
|
self.step_count += 1
|
||||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||||
loss = losses["loss"]
|
loss = losses["loss"]
|
||||||
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
|
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
|
||||||
rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
|
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
|
||||||
rouges.update({k: v.item() for k, v in losses.items()})
|
rouges.update({k: v.item() for k, v in losses.items()})
|
||||||
losses.update(rouges)
|
losses.update(rouges)
|
||||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||||
metrics["step_count"] = self.step_count
|
metrics["step_count"] = self.step_count
|
||||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||||
preds = flatten_list([x["preds"] for x in outputs])
|
preds = flatten_list([x["preds"] for x in outputs])
|
||||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}
|
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
|
||||||
|
|
||||||
def save_metrics(self, metrics, prefix) -> None:
|
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||||
self.metrics[prefix].append(metrics)
|
self.metrics[type_path].append(latest_metrics)
|
||||||
pickle_save(self.metrics, self.metrics_save_path)
|
save_json(self.metrics, self.metrics_save_path)
|
||||||
|
|
||||||
|
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||||
|
return calculate_rouge(preds, target)
|
||||||
|
|
||||||
def _generative_step(self, batch: dict) -> dict:
|
def _generative_step(self, batch: dict) -> dict:
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
@@ -154,7 +165,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
target = self.ids_to_clean_text(y)
|
target = self.ids_to_clean_text(y)
|
||||||
loss_tensors = self._step(batch)
|
loss_tensors = self._step(batch)
|
||||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||||
rouge: Dict = calculate_rouge(preds, target)
|
rouge: Dict = self.calc_generative_metrics(preds, target)
|
||||||
summ_len = np.mean(lmap(len, generated_ids))
|
summ_len = np.mean(lmap(len, generated_ids))
|
||||||
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
|
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
|
||||||
return base_metrics
|
return base_metrics
|
||||||
@@ -259,15 +270,33 @@ class SummarizationModule(BaseTransformer):
|
|||||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationModule(SummarizationModule):
|
||||||
|
mode = "translation"
|
||||||
|
loss_names = ["loss"]
|
||||||
|
metric_names = ["bleu"]
|
||||||
|
val_metric = "bleu"
|
||||||
|
|
||||||
|
def calc_generative_metrics(self, preds, target) -> dict:
|
||||||
|
return calculate_bleu_score(preds, target)
|
||||||
|
|
||||||
|
|
||||||
def main(args, model=None) -> SummarizationModule:
|
def main(args, model=None) -> SummarizationModule:
|
||||||
Path(args.output_dir).mkdir(exist_ok=True)
|
Path(args.output_dir).mkdir(exist_ok=True)
|
||||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
if model is None:
|
if model is None:
|
||||||
model: BaseTransformer = SummarizationModule(args)
|
if args.task == "summarization":
|
||||||
|
model: SummarizationModule = SummarizationModule(args)
|
||||||
|
else:
|
||||||
|
model: SummarizationModule = TranslationModule(args)
|
||||||
|
|
||||||
|
dataset = Path(args.data_dir).name
|
||||||
if (
|
if (
|
||||||
args.logger == "default"
|
args.logger == "default"
|
||||||
or args.fast_dev_run
|
or args.fast_dev_run
|
||||||
@@ -278,17 +307,17 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
elif args.logger == "wandb":
|
elif args.logger == "wandb":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
logger = WandbLogger(name=model.output_dir.name)
|
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||||
|
|
||||||
elif args.logger == "wandb_shared":
|
elif args.logger == "wandb_shared":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
# TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
|
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||||
logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
|
|
||||||
trainer: pl.Trainer = generic_train(
|
trainer: pl.Trainer = generic_train(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
logging_callback=Seq2SeqLoggingCallback(),
|
logging_callback=Seq2SeqLoggingCallback(),
|
||||||
checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
|
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
# TODO: early stopping callback seems messed up
|
# TODO: early stopping callback seems messed up
|
||||||
)
|
)
|
||||||
@@ -1,13 +1,8 @@
|
|||||||
|
|
||||||
# Add parent directory to python path to access lightning_base.py
|
# Add parent directory to python path to access lightning_base.py
|
||||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
|
||||||
|
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||||
# --model_name_or_path=t5-base for t5
|
|
||||||
|
|
||||||
# the proper usage is documented in the README
|
|
||||||
python finetune.py \
|
python finetune.py \
|
||||||
--model_name_or_path=facebook/bart-large \
|
|
||||||
--learning_rate=3e-5 \
|
--learning_rate=3e-5 \
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--gpus 1 \
|
--gpus 1 \
|
||||||
@@ -16,5 +11,4 @@ python finetune.py \
|
|||||||
--n_val 1000 \
|
--n_val 1000 \
|
||||||
--val_check_interval 0.1 \
|
--val_check_interval 0.1 \
|
||||||
--sortish_sampler \
|
--sortish_sampler \
|
||||||
--max_target_length=56 \
|
|
||||||
$@
|
$@
|
||||||
0
examples/summarization/finetune_bart_tiny.sh → examples/seq2seq/finetune_bart_tiny.sh
Normal file → Executable file
0
examples/summarization/finetune_bart_tiny.sh → examples/seq2seq/finetune_bart_tiny.sh
Normal file → Executable file
0
examples/summarization/finetune_t5.sh → examples/seq2seq/finetune_t5.sh
Normal file → Executable file
0
examples/summarization/finetune_t5.sh → examples/seq2seq/finetune_t5.sh
Normal file → Executable file
@@ -1,5 +1,3 @@
|
|||||||
#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm
|
|
||||||
|
|
||||||
# Add parent directory to python path to access lightning_base.py
|
# Add parent directory to python path to access lightning_base.py
|
||||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
|
||||||
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .finetune import calculate_rouge, use_task_specific_params
|
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from finetune import calculate_rouge, use_task_specific_params
|
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
@@ -22,8 +22,14 @@ def chunks(lst, n):
|
|||||||
yield lst[i : i + n]
|
yield lst[i : i + n]
|
||||||
|
|
||||||
|
|
||||||
def generate_summaries(
|
def generate_summaries_or_translations(
|
||||||
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
|
examples: list,
|
||||||
|
out_file: str,
|
||||||
|
model_name: str,
|
||||||
|
batch_size: int = 8,
|
||||||
|
device: str = DEFAULT_DEVICE,
|
||||||
|
fp16=False,
|
||||||
|
**gen_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
fout = Path(out_file).open("w", encoding="utf-8")
|
fout = Path(out_file).open("w", encoding="utf-8")
|
||||||
model_name = str(model_name)
|
model_name = str(model_name)
|
||||||
@@ -39,11 +45,10 @@ def generate_summaries(
|
|||||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||||
if "t5" in model_name:
|
if "t5" in model_name:
|
||||||
batch = [model.config.prefix + text for text in batch]
|
batch = [model.config.prefix + text for text in batch]
|
||||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
|
batch = tokenizer.batch_encode_plus(
|
||||||
device
|
batch, max_length=1024, return_tensors="pt", truncation=True, pad_to_max_length=True
|
||||||
)
|
).to(device)
|
||||||
summaries = model.generate(**dct)
|
summaries = model.generate(**batch, **gen_kwargs)
|
||||||
|
|
||||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
for hypothesis in dec:
|
for hypothesis in dec:
|
||||||
fout.write(hypothesis + "\n")
|
fout.write(hypothesis + "\n")
|
||||||
@@ -57,22 +62,26 @@ def run_generate():
|
|||||||
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
||||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
||||||
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
|
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
|
||||||
|
parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
|
||||||
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
||||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||||
parser.add_argument("--fp16", action="store_true")
|
parser.add_argument("--fp16", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||||
|
|
||||||
generate_summaries(
|
generate_summaries_or_translations(
|
||||||
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
|
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
|
||||||
)
|
)
|
||||||
if args.score_path is not None:
|
|
||||||
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
||||||
|
scores = {}
|
||||||
|
if args.reference_path is not None:
|
||||||
|
score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
|
||||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
|
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
|
||||||
|
scores: dict = score_fn(output_lns, reference_lns)
|
||||||
rouge: dict = calculate_rouge(output_lns, reference_lns)
|
if args.score_path is not None:
|
||||||
|
json.dump(scores, open("score_path", "w+"))
|
||||||
json.dump(rouge, open("score_path", "w+"))
|
return scores
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
252
examples/seq2seq/test_seq2seq_examples.py
Normal file
252
examples/seq2seq/test_seq2seq_examples.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from .distillation import distill_main, evaluate_checkpoint
|
||||||
|
from .finetune import main
|
||||||
|
from .run_eval import generate_summaries_or_translations, run_generate
|
||||||
|
from .utils import SummarizationDataset, lmap, load_json
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
|
CHEAP_ARGS = {
|
||||||
|
"logger": "default",
|
||||||
|
"length_penalty": 0.5,
|
||||||
|
"cache_dir": "",
|
||||||
|
"task": "summarization",
|
||||||
|
"num_workers": 2,
|
||||||
|
"alpha_hid": 0,
|
||||||
|
"freeze_embeds": True,
|
||||||
|
"enc_only": False,
|
||||||
|
"tgt_suffix": "",
|
||||||
|
"resume_from_checkpoint": None,
|
||||||
|
"sortish_sampler": True,
|
||||||
|
"student_decoder_layers": 1,
|
||||||
|
"val_check_interval": 1.0,
|
||||||
|
"output_dir": "",
|
||||||
|
"fp16": CUDA_AVAILABLE,
|
||||||
|
"no_teacher": False,
|
||||||
|
"fp16_opt_level": "O1",
|
||||||
|
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||||
|
"n_tpu_cores": 0,
|
||||||
|
"max_grad_norm": 1.0,
|
||||||
|
"do_train": True,
|
||||||
|
"do_predict": True,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"server_ip": "",
|
||||||
|
"server_port": "",
|
||||||
|
"seed": 42,
|
||||||
|
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||||
|
"config_name": "",
|
||||||
|
"tokenizer_name": "facebook/bart-large",
|
||||||
|
"do_lower_case": False,
|
||||||
|
"learning_rate": 0.3,
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"adam_epsilon": 1e-08,
|
||||||
|
"warmup_steps": 0,
|
||||||
|
"num_train_epochs": 1,
|
||||||
|
"train_batch_size": 2,
|
||||||
|
"eval_batch_size": 2,
|
||||||
|
"max_source_length": 12,
|
||||||
|
"max_target_length": 12,
|
||||||
|
"val_max_target_length": 12,
|
||||||
|
"test_max_target_length": 12,
|
||||||
|
"fast_dev_run": False,
|
||||||
|
"no_cache": False,
|
||||||
|
"n_train": -1,
|
||||||
|
"n_val": -1,
|
||||||
|
"n_test": -1,
|
||||||
|
"student_encoder_layers": 1,
|
||||||
|
"alpha_loss_encoder": 0.0,
|
||||||
|
"freeze_encoder": False,
|
||||||
|
"auto_scale_batch_size": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_articles(path: Path, articles: list):
|
||||||
|
with path.open("w") as f:
|
||||||
|
f.write("\n".join(articles))
|
||||||
|
|
||||||
|
|
||||||
|
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||||
|
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||||
|
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||||
|
BART_TINY = "sshleifer/bart-tiny-random"
|
||||||
|
MBART_TINY = "sshleifer/tiny-mbart"
|
||||||
|
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_data_dir(**kwargs):
|
||||||
|
tmp_dir = Path(tempfile.mkdtemp(**kwargs))
|
||||||
|
for split in ["train", "val", "test"]:
|
||||||
|
_dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
|
||||||
|
_dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
|
||||||
|
return tmp_dir
|
||||||
|
|
||||||
|
|
||||||
|
class TestSummarizationDistiller(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
|
||||||
|
def test_multigpu(self):
|
||||||
|
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
|
||||||
|
self._test_distiller_cli(updates)
|
||||||
|
|
||||||
|
def test_distill_no_teacher(self):
|
||||||
|
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||||
|
self._test_distiller_cli(updates)
|
||||||
|
|
||||||
|
def test_distill_checkpointing_with_teacher(self):
|
||||||
|
updates = dict(
|
||||||
|
student_encoder_layers=2,
|
||||||
|
student_decoder_layers=1,
|
||||||
|
num_train_epochs=4,
|
||||||
|
val_check_interval=0.25,
|
||||||
|
alpha_hid=2.0,
|
||||||
|
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||||
|
)
|
||||||
|
model = self._test_distiller_cli(updates, check_contents=False)
|
||||||
|
|
||||||
|
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||||
|
self.assertEqual(1, len(ckpts))
|
||||||
|
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||||
|
self.assertEqual(len(transformer_ckpts), 2)
|
||||||
|
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
||||||
|
out_path = tempfile.mktemp()
|
||||||
|
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
||||||
|
self.assertTrue(Path(out_path).exists())
|
||||||
|
|
||||||
|
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||||
|
|
||||||
|
@unittest.skip("T5 distillation is broken at the moment")
|
||||||
|
def test_distill_t5(self):
|
||||||
|
updates = dict(
|
||||||
|
student_encoder_layers=1,
|
||||||
|
student_decoder_layers=1,
|
||||||
|
alpha_hid=2.0,
|
||||||
|
teacher=T5_TINY,
|
||||||
|
model_name_or_path=T5_TINY,
|
||||||
|
tokenizer_name=T5_TINY,
|
||||||
|
)
|
||||||
|
self._test_distiller_cli(updates)
|
||||||
|
|
||||||
|
def _test_distiller_cli(self, updates, check_contents=True):
|
||||||
|
default_updates = dict(
|
||||||
|
train_batch_size=1,
|
||||||
|
eval_batch_size=2,
|
||||||
|
num_train_epochs=2,
|
||||||
|
alpha_mlm=0.2,
|
||||||
|
alpha_ce=0.8,
|
||||||
|
do_predict=True,
|
||||||
|
model_name_or_path="sshleifer/tinier_bart",
|
||||||
|
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||||
|
val_check_interval=0.5,
|
||||||
|
alpha_encoder_loss=0.4,
|
||||||
|
)
|
||||||
|
default_updates.update(updates)
|
||||||
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
|
tmp_dir = make_test_data_dir()
|
||||||
|
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||||
|
|
||||||
|
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||||
|
model = distill_main(argparse.Namespace(**args_d))
|
||||||
|
if not check_contents:
|
||||||
|
return model
|
||||||
|
contents = os.listdir(output_dir)
|
||||||
|
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
|
||||||
|
contents = {os.path.basename(p) for p in contents}
|
||||||
|
self.assertIn(ckpt_name, contents)
|
||||||
|
|
||||||
|
self.assertIn("test_generations.txt", contents)
|
||||||
|
self.assertIn("test_results.txt", contents)
|
||||||
|
|
||||||
|
metrics = load_json(model.metrics_save_path)
|
||||||
|
last_step_stats = metrics["val"][-1]
|
||||||
|
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||||
|
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||||
|
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||||
|
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||||
|
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||||
|
self.assertEqual(len(metrics["test"]), 1)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||||
|
def test_run_eval_bart(model):
|
||||||
|
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
||||||
|
|
||||||
|
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
|
||||||
|
assert not output_file_name.exists()
|
||||||
|
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||||
|
_dump_articles(tmp, articles)
|
||||||
|
testargs = ["run_eval.py", str(tmp), str(output_file_name), model] # TODO: test score_path
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_generate()
|
||||||
|
assert Path(output_file_name).exists()
|
||||||
|
os.remove(Path(output_file_name))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||||
|
)
|
||||||
|
def test_finetune(model):
|
||||||
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
|
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
|
||||||
|
tmp_dir = make_test_data_dir()
|
||||||
|
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||||
|
args_d.update(
|
||||||
|
data_dir=tmp_dir,
|
||||||
|
model_name_or_path=model,
|
||||||
|
tokenizer_name=None,
|
||||||
|
train_batch_size=2,
|
||||||
|
eval_batch_size=2,
|
||||||
|
output_dir=output_dir,
|
||||||
|
do_predict=True,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
assert "n_train" in args_d
|
||||||
|
args = argparse.Namespace(**args_d)
|
||||||
|
main(args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||||
|
)
|
||||||
|
def test_dataset(tok):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||||
|
tmp_dir = make_test_data_dir()
|
||||||
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
|
trunc_target = 4
|
||||||
|
train_dataset = SummarizationDataset(
|
||||||
|
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||||
|
)
|
||||||
|
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||||
|
for batch in dataloader:
|
||||||
|
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||||
|
# show that articles were trimmed.
|
||||||
|
assert batch["input_ids"].shape[1] == max_len_source
|
||||||
|
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
||||||
|
# show that targets were truncated
|
||||||
|
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
||||||
|
assert max_len_target > trunc_target # Truncated
|
||||||
24
examples/seq2seq/train_distilbart_cnn.sh
Executable file
24
examples/seq2seq/train_distilbart_cnn.sh
Executable file
@@ -0,0 +1,24 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
|
||||||
|
export BS=32
|
||||||
|
export GAS=1
|
||||||
|
|
||||||
|
python finetune.py \
|
||||||
|
--learning_rate=3e-5 \
|
||||||
|
--fp16 \
|
||||||
|
--gpus 1 \
|
||||||
|
--do_train \
|
||||||
|
--do_predict \
|
||||||
|
--val_check_interval 0.25 \
|
||||||
|
--n_val 500 \
|
||||||
|
--num_train_epochs 2 \
|
||||||
|
--freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
|
||||||
|
--max_target_length 142 --val_max_target_length=142 \
|
||||||
|
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
|
||||||
|
--data_dir $CNN_DIR \
|
||||||
|
--model_name_or_path sshleifer/student_cnn_12_6 \
|
||||||
|
--tokenizer_name facebook/bart-large \
|
||||||
|
--output_dir distilbart-cnn-12-6 \
|
||||||
|
$@
|
||||||
|
|
||||||
20
examples/seq2seq/train_distilbart_xsum.sh
Executable file
20
examples/seq2seq/train_distilbart_xsum.sh
Executable file
@@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
export BS=16
|
||||||
|
export GAS=2
|
||||||
|
python distillation.py \
|
||||||
|
--learning_rate=3e-4 \
|
||||||
|
--do_train \
|
||||||
|
--do_predict \
|
||||||
|
--fp16 \
|
||||||
|
--val_check_interval 0.1 --n_val 1000 \
|
||||||
|
--teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \
|
||||||
|
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||||
|
--student_decoder_layers 6 --student_encoder_layers 12 \
|
||||||
|
--freeze_encoder --freeze_embeds \
|
||||||
|
--model_name_or_path IGNORED \
|
||||||
|
--alpha_hid=3. --length_penalty=0.5 \
|
||||||
|
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
|
||||||
|
--tokenizer_name facebook/bart-large \
|
||||||
|
--output_dir distilbart_xsum_12_6 \
|
||||||
|
$@
|
||||||
@@ -3,12 +3,13 @@ import json
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Iterable, List
|
from typing import Callable, Dict, Iterable, List
|
||||||
|
|
||||||
import git
|
import git
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from rouge_score import rouge_scorer, scoring
|
from rouge_score import rouge_scorer, scoring
|
||||||
|
from sacrebleu import corpus_bleu
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset, Sampler
|
from torch.utils.data import Dataset, Sampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -41,7 +42,7 @@ def encode_file(
|
|||||||
examples = []
|
examples = []
|
||||||
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
|
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
|
||||||
tokenized = tokenizer.batch_encode_plus(
|
tokenized = tokenizer.batch_encode_plus(
|
||||||
[text], # DONT ADD SPACES
|
[text],
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_to_max_length=pad_to_max_length,
|
pad_to_max_length=pad_to_max_length,
|
||||||
add_prefix_space=True,
|
add_prefix_space=True,
|
||||||
@@ -54,11 +55,13 @@ def encode_file(
|
|||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def lmap(f, x):
|
def lmap(f: Callable, x: Iterable) -> List:
|
||||||
|
"""list(map(f, x))"""
|
||||||
return list(map(f, x))
|
return list(map(f, x))
|
||||||
|
|
||||||
|
|
||||||
T5_PREFIX = "summarize: " # HACK, fixme
|
def calculate_bleu_score(output_lns, refs_lns) -> dict:
|
||||||
|
return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}
|
||||||
|
|
||||||
|
|
||||||
def trim_batch(
|
def trim_batch(
|
||||||
@@ -95,6 +98,8 @@ class SummarizationDataset(Dataset):
|
|||||||
tok_name=tok_name,
|
tok_name=tok_name,
|
||||||
)
|
)
|
||||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||||
|
if hasattr(tokenizer, "set_lang"):
|
||||||
|
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
|
||||||
self.target = encode_file(
|
self.target = encode_file(
|
||||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||||
)
|
)
|
||||||
@@ -189,14 +194,20 @@ def flatten_list(summary_ids: List[List]):
|
|||||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||||
|
|
||||||
|
|
||||||
def save_git_info(folder_path: str):
|
def save_git_info(folder_path: str) -> None:
|
||||||
"""
|
"""Save git information to output_dir/git_log.json"""
|
||||||
Log commit info.
|
|
||||||
"""
|
|
||||||
repo_infos = get_git_info()
|
repo_infos = get_git_info()
|
||||||
|
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
||||||
|
|
||||||
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
|
||||||
json.dump(repo_infos, f, indent=4)
|
def save_json(content, path):
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(content, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path):
|
||||||
|
with open(path) as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def get_git_info():
|
def get_git_info():
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
### Data
|
|
||||||
|
|
||||||
CNN/DailyMail data
|
|
||||||
```bash
|
|
||||||
cd examples/summarization
|
|
||||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
|
|
||||||
tar -xzvf cnn_dm.tgz
|
|
||||||
export CNN_DIR=${PWD}/cnn_dm
|
|
||||||
```
|
|
||||||
|
|
||||||
this should make a directory called cnn_dm/ with files like `test.source`.
|
|
||||||
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
|
||||||
|
|
||||||
XSUM Data:
|
|
||||||
```bash
|
|
||||||
cd examples/summarization
|
|
||||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
|
||||||
tar -xzvf xsum.tar.gz
|
|
||||||
export XSUM_DIR=${PWD}/xsum
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### Evaluation
|
|
||||||
|
|
||||||
To create summaries for each article in dataset, run:
|
|
||||||
```bash
|
|
||||||
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
|
|
||||||
```
|
|
||||||
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
|
||||||
|
|
||||||
|
|
||||||
### Training
|
|
||||||
Run/modify `finetune.sh`
|
|
||||||
|
|
||||||
The following command should work on a 16GB GPU:
|
|
||||||
```bash
|
|
||||||
export me=`git config user.name`
|
|
||||||
./finetune.sh \
|
|
||||||
--data_dir $XSUM_DIR \
|
|
||||||
--train_batch_size=1 \
|
|
||||||
--eval_batch_size=1 \
|
|
||||||
--output_dir="$me"_xsum_results \
|
|
||||||
--num_train_epochs 1
|
|
||||||
```
|
|
||||||
|
|
||||||
Tips:
|
|
||||||
- 1 epoch at batch size 1 for bart-large takes 24 hours, requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
|
|
||||||
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see below)
|
|
||||||
- `fp16_opt_level=O1` (the default works best).
|
|
||||||
- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries.
|
|
||||||
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
|
||||||
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
|
|
||||||
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
|
|
||||||
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
|
|
||||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
|
||||||
|
|
||||||
### XSUM Shared Task
|
|
||||||
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
|
|
||||||
Here is an example command
|
|
||||||
```bash
|
|
||||||
export me=`git config user.name`
|
|
||||||
./finetune.sh \
|
|
||||||
--data_dir $XSUM_DIR \
|
|
||||||
--output_dir "$me"_xsum_frozen_embs \
|
|
||||||
--logger wandb_shared \
|
|
||||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
|
||||||
--num_train_epochs 6
|
|
||||||
```
|
|
||||||
|
|
||||||
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-)
|
|
||||||
@@ -1,267 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from transformers import BartTokenizer
|
|
||||||
|
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
|
||||||
from .finetune import main
|
|
||||||
from .run_eval import generate_summaries, run_generate
|
|
||||||
from .utils import SummarizationDataset, lmap, pickle_load
|
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
|
||||||
FP16_EVER = False
|
|
||||||
CHEAP_ARGS = {
|
|
||||||
"logger": "default",
|
|
||||||
"num_workers": 2,
|
|
||||||
"alpha_hid": 0,
|
|
||||||
"freeze_embeds": True,
|
|
||||||
"enc_only": False,
|
|
||||||
"tgt_suffix": "",
|
|
||||||
"resume_from_checkpoint": None,
|
|
||||||
"sortish_sampler": True,
|
|
||||||
"student_decoder_layers": 1,
|
|
||||||
"val_check_interval": 1.0,
|
|
||||||
"output_dir": "",
|
|
||||||
"fp16": False,
|
|
||||||
"no_teacher": False,
|
|
||||||
"fp16_opt_level": "O1",
|
|
||||||
"gpus": 1 if torch.cuda.is_available() else 0,
|
|
||||||
"n_tpu_cores": 0,
|
|
||||||
"max_grad_norm": 1.0,
|
|
||||||
"do_train": True,
|
|
||||||
"do_predict": True,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"server_ip": "",
|
|
||||||
"server_port": "",
|
|
||||||
"seed": 42,
|
|
||||||
"model_type": "bart",
|
|
||||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
|
||||||
"config_name": "",
|
|
||||||
"tokenizer_name": "facebook/bart-large",
|
|
||||||
"cache_dir": "",
|
|
||||||
"do_lower_case": False,
|
|
||||||
"learning_rate": 3e-05,
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"adam_epsilon": 1e-08,
|
|
||||||
"warmup_steps": 0,
|
|
||||||
"num_train_epochs": 1,
|
|
||||||
"train_batch_size": 2,
|
|
||||||
"eval_batch_size": 2,
|
|
||||||
"max_source_length": 12,
|
|
||||||
"max_target_length": 12,
|
|
||||||
"val_max_target_length": 12,
|
|
||||||
"test_max_target_length": 12,
|
|
||||||
"fast_dev_run": False,
|
|
||||||
"no_cache": False,
|
|
||||||
"n_train": -1,
|
|
||||||
"n_val": -1,
|
|
||||||
"n_test": -1,
|
|
||||||
"student_encoder_layers": 1,
|
|
||||||
"alpha_loss_encoder": 0.0,
|
|
||||||
"freeze_encoder": False,
|
|
||||||
"auto_scale_batch_size": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _dump_articles(path: Path, articles: list):
|
|
||||||
with path.open("w") as f:
|
|
||||||
f.write("\n".join(articles))
|
|
||||||
|
|
||||||
|
|
||||||
MSG = "T5 is broken at the moment"
|
|
||||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
|
||||||
|
|
||||||
|
|
||||||
def make_test_data_dir():
|
|
||||||
tmp_dir = Path(tempfile.gettempdir())
|
|
||||||
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
|
|
||||||
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
|
||||||
for split in ["train", "val", "test"]:
|
|
||||||
_dump_articles((tmp_dir / f"{split}.source"), articles)
|
|
||||||
_dump_articles((tmp_dir / f"{split}.target"), summaries)
|
|
||||||
return tmp_dir
|
|
||||||
|
|
||||||
|
|
||||||
class TestSummarizationDistiller(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
|
||||||
return cls
|
|
||||||
|
|
||||||
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
|
|
||||||
def test_bdc_multigpu(self):
|
|
||||||
updates = dict(
|
|
||||||
student_encoder_layers=2,
|
|
||||||
student_decoder_layers=1,
|
|
||||||
no_teacher=True,
|
|
||||||
freeze_encoder=True,
|
|
||||||
gpus=2,
|
|
||||||
sortish_sampler=False,
|
|
||||||
fp16_opt_level="O1",
|
|
||||||
fp16=FP16_EVER,
|
|
||||||
)
|
|
||||||
self._bart_distiller_cli(updates)
|
|
||||||
|
|
||||||
def test_bdc_t5_train(self):
|
|
||||||
updates = dict(
|
|
||||||
fp16=FP16_EVER,
|
|
||||||
gpus=1 if torch.cuda.is_available() else 0,
|
|
||||||
model_type="t5",
|
|
||||||
model_name_or_path=T5_TINY,
|
|
||||||
do_train=True,
|
|
||||||
do_predict=True,
|
|
||||||
tokenizer_name=T5_TINY,
|
|
||||||
no_teacher=True,
|
|
||||||
alpha_hid=2.0,
|
|
||||||
)
|
|
||||||
self._bart_distiller_cli(updates)
|
|
||||||
|
|
||||||
def test_bdc_no_teacher(self):
|
|
||||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,)
|
|
||||||
self._bart_distiller_cli(updates)
|
|
||||||
|
|
||||||
def test_bdc_yes_teacher(self):
|
|
||||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1,)
|
|
||||||
self._bart_distiller_cli(updates)
|
|
||||||
|
|
||||||
def test_bdc_checkpointing(self):
|
|
||||||
updates = dict(
|
|
||||||
student_encoder_layers=2,
|
|
||||||
student_decoder_layers=1,
|
|
||||||
num_train_epochs=4,
|
|
||||||
val_check_interval=0.25,
|
|
||||||
alpha_hid=2.0,
|
|
||||||
)
|
|
||||||
model = self._bart_distiller_cli(updates, check_contents=False)
|
|
||||||
|
|
||||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
|
||||||
self.assertEqual(1, len(ckpts))
|
|
||||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
|
||||||
self.assertEqual(len(transformer_ckpts), len(ckpts))
|
|
||||||
new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
|
||||||
self.assertEqual(len(new_transformer_ckpts), 1)
|
|
||||||
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
|
||||||
out_path = tempfile.mktemp()
|
|
||||||
generate_summaries(examples, out_path, new_transformer_ckpts[0].parent)
|
|
||||||
self.assertTrue(Path(out_path).exists())
|
|
||||||
|
|
||||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
|
||||||
|
|
||||||
def _bart_distiller_cli(self, updates, check_contents=True):
|
|
||||||
default_updates = dict(
|
|
||||||
train_batch_size=1,
|
|
||||||
eval_batch_size=2,
|
|
||||||
num_train_epochs=2,
|
|
||||||
alpha_mlm=0.2,
|
|
||||||
alpha_ce=0.8,
|
|
||||||
do_predict=True,
|
|
||||||
gpus=1 if torch.cuda.is_available() else 0,
|
|
||||||
model_name_or_path="sshleifer/tinier_bart",
|
|
||||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
|
||||||
val_check_interval=0.5,
|
|
||||||
alpha_encoder_loss=0.4,
|
|
||||||
)
|
|
||||||
default_updates.update(updates)
|
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
|
||||||
tmp_dir = make_test_data_dir()
|
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
|
||||||
|
|
||||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
|
||||||
model = distill_main(argparse.Namespace(**args_d))
|
|
||||||
if not check_contents:
|
|
||||||
return model
|
|
||||||
contents = os.listdir(output_dir)
|
|
||||||
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
|
|
||||||
contents = {os.path.basename(p) for p in contents}
|
|
||||||
self.assertIn(ckpt_name, contents)
|
|
||||||
self.assertIn("metrics.pkl", contents)
|
|
||||||
self.assertIn("test_generations.txt", contents)
|
|
||||||
self.assertIn("val_generations_00001.txt", contents)
|
|
||||||
self.assertIn("val_results_00001.txt", contents)
|
|
||||||
self.assertIn("test_results.txt", contents)
|
|
||||||
|
|
||||||
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
|
|
||||||
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
|
||||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
|
||||||
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class TestBartExamples(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
|
||||||
return cls
|
|
||||||
|
|
||||||
def test_bart_cnn_cli(self):
|
|
||||||
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
|
||||||
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
|
|
||||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
|
||||||
_dump_articles(tmp, articles)
|
|
||||||
testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
|
|
||||||
with patch.object(sys, "argv", testargs):
|
|
||||||
run_generate()
|
|
||||||
self.assertTrue(Path(output_file_name).exists())
|
|
||||||
os.remove(Path(output_file_name))
|
|
||||||
|
|
||||||
def test_t5_run_sum_cli(self):
|
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
|
||||||
|
|
||||||
tmp_dir = make_test_data_dir()
|
|
||||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
|
||||||
args_d.update(
|
|
||||||
data_dir=tmp_dir,
|
|
||||||
model_name_or_path=T5_TINY,
|
|
||||||
tokenizer_name=None, # T5_TINY,
|
|
||||||
train_batch_size=2,
|
|
||||||
eval_batch_size=2,
|
|
||||||
gpus=0,
|
|
||||||
output_dir=output_dir,
|
|
||||||
do_predict=True,
|
|
||||||
)
|
|
||||||
assert "n_train" in args_d
|
|
||||||
args = argparse.Namespace(**args_d)
|
|
||||||
main(args)
|
|
||||||
|
|
||||||
def test_bart_summarization_dataset(self):
|
|
||||||
tmp_dir = Path(tempfile.gettempdir())
|
|
||||||
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
|
|
||||||
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
|
||||||
_dump_articles((tmp_dir / "train.source"), articles)
|
|
||||||
_dump_articles((tmp_dir / "train.target"), summaries)
|
|
||||||
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
|
||||||
max_len_source = max(len(tokenizer.encode(a)) for a in articles)
|
|
||||||
max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
|
|
||||||
trunc_target = 4
|
|
||||||
train_dataset = SummarizationDataset(
|
|
||||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
|
||||||
)
|
|
||||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
|
||||||
for batch in dataloader:
|
|
||||||
self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape)
|
|
||||||
# show that articles were trimmed.
|
|
||||||
self.assertEqual(batch["input_ids"].shape[1], max_len_source)
|
|
||||||
self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly
|
|
||||||
|
|
||||||
# show that targets were truncated
|
|
||||||
self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated
|
|
||||||
self.assertGreater(max_len_target, trunc_target) # Truncated
|
|
||||||
|
|
||||||
|
|
||||||
def list_to_text_file(lst, path):
|
|
||||||
dest = Path(path)
|
|
||||||
dest.open("w+").writelines(lst)
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the English to German WMT dataset. Please note that the results in the paper were attained using a model fine-tuned on translation, so that results will be worse here by approx. 1.5 BLEU points***
|
|
||||||
|
|
||||||
### Intro
|
|
||||||
|
|
||||||
This example shows how T5 (here the official [paper](https://arxiv.org/abs/1910.10683)) can be
|
|
||||||
evaluated on the WMT English-German dataset.
|
|
||||||
|
|
||||||
### Get the WMT Data
|
|
||||||
|
|
||||||
To be able to reproduce the authors' results on WMT English to German, you first need to download
|
|
||||||
the WMT14 en-de news datasets.
|
|
||||||
Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2014.en" and "newstest2014.de" under WMT'14 English-German data or download the dataset directly via:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en > newstest2014.en
|
|
||||||
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de > newstest2014.de
|
|
||||||
```
|
|
||||||
|
|
||||||
You should have 2737 sentences in each file. You can verify this by running:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
wc -l newstest2014.en # should give 2737
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
Let's check the longest and shortest sentence in our file to find reasonable decoding hyperparameters:
|
|
||||||
|
|
||||||
Get the longest and shortest sentence:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
awk '{print NF}' newstest2014.en | sort -n | head -1 # shortest sentence has 2 word
|
|
||||||
awk '{print NF}' newstest2014.en | sort -n | tail -1 # longest sentence has 91 words
|
|
||||||
```
|
|
||||||
|
|
||||||
We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0.
|
|
||||||
We decode with beam search `num_beams=4` as proposed in the paper. Also as is common in beam search we set `early_stopping=True` and `length_penalty=2.0`.
|
|
||||||
|
|
||||||
To create translation for each in dataset and get a final BLEU score, run:
|
|
||||||
```bash
|
|
||||||
python evaluate_wmt.py <path_to_newstest2014.en> newstest2014_de_translations.txt <path_to_newstest2014.de> newsstest2014_en_de_bleu.txt
|
|
||||||
```
|
|
||||||
the default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
|
||||||
|
|
||||||
### Where is the code?
|
|
||||||
The core model is in `src/transformers/modeling_t5.py`. This directory only contains examples.
|
|
||||||
|
|
||||||
### BLEU Scores
|
|
||||||
|
|
||||||
The BLEU score is calculated using [sacrebleu](https://github.com/mjpost/sacreBLEU) by mjpost.
|
|
||||||
To get the BLEU score we used
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from sacrebleu import corpus_bleu
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def chunks(lst, n):
|
|
||||||
"""Yield successive n-sized chunks from lst."""
|
|
||||||
for i in range(0, len(lst), n):
|
|
||||||
yield lst[i : i + n]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_translations(lns, output_file_path, model_size, batch_size, device):
|
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_size)
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
tokenizer = T5Tokenizer.from_pretrained(model_size)
|
|
||||||
|
|
||||||
# update config with summarization specific params
|
|
||||||
task_specific_params = model.config.task_specific_params
|
|
||||||
if task_specific_params is not None:
|
|
||||||
model.config.update(task_specific_params.get("translation_en_to_de", {}))
|
|
||||||
|
|
||||||
with Path(output_file_path).open("w") as output_file:
|
|
||||||
for batch in tqdm(list(chunks(lns, batch_size))):
|
|
||||||
batch = [model.config.prefix + text for text in batch]
|
|
||||||
|
|
||||||
dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
|
|
||||||
|
|
||||||
input_ids = dct["input_ids"].to(device)
|
|
||||||
attention_mask = dct["attention_mask"].to(device)
|
|
||||||
|
|
||||||
translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
dec = [
|
|
||||||
tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations
|
|
||||||
]
|
|
||||||
|
|
||||||
for hypothesis in dec:
|
|
||||||
output_file.write(hypothesis + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_bleu_score(output_lns, refs_lns, score_path):
|
|
||||||
bleu = corpus_bleu(output_lns, [refs_lns])
|
|
||||||
result = "BLEU score: {}".format(bleu.score)
|
|
||||||
with Path(score_path).open("w") as score_file:
|
|
||||||
score_file.write(result)
|
|
||||||
|
|
||||||
|
|
||||||
def run_generate():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"model_size",
|
|
||||||
type=str,
|
|
||||||
help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
|
|
||||||
default="t5-base",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"input_path", type=str, help="like wmt/newstest2014.en",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"output_path", type=str, help="where to save translation",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"reference_path", type=str, help="like wmt/newstest2014.de",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"score_path", type=str, help="where to save the bleu score",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_size", type=int, default=16, required=False, help="batch size: how many to summarize at a time",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
|
||||||
|
|
||||||
dash_pattern = (" ##AT##-##AT## ", "-")
|
|
||||||
|
|
||||||
# Read input lines into python
|
|
||||||
with open(args.input_path, "r") as input_file:
|
|
||||||
input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in input_file.readlines()]
|
|
||||||
|
|
||||||
generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)
|
|
||||||
|
|
||||||
# Read generated lines into python
|
|
||||||
with open(args.output_path, "r") as output_file:
|
|
||||||
output_lns = [x.strip() for x in output_file.readlines()]
|
|
||||||
|
|
||||||
# Read reference lines into python
|
|
||||||
with open(args.reference_path, "r") as reference_file:
|
|
||||||
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in reference_file.readlines()]
|
|
||||||
|
|
||||||
calculate_bleu_score(output_lns, refs_lns, args.score_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
run_generate()
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
import logging
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from .evaluate_wmt import run_generate
|
|
||||||
|
|
||||||
|
|
||||||
text = ["When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
|
||||||
translation = ["Als Liana Barrientos 23 Jahre alt war, heiratete sie in Westchester County."]
|
|
||||||
|
|
||||||
output_file_name = "output_t5_trans.txt"
|
|
||||||
score_file_name = "score_t5_trans.txt"
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
class TestT5Examples(unittest.TestCase):
|
|
||||||
def test_t5_cli(self):
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_source = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.hypo"
|
|
||||||
with tmp_source.open("w") as f:
|
|
||||||
f.write("\n".join(text))
|
|
||||||
|
|
||||||
tmp_target = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.target"
|
|
||||||
with tmp_target.open("w") as f:
|
|
||||||
f.write("\n".join(translation))
|
|
||||||
|
|
||||||
output_file_name = Path(tempfile.gettempdir()) / "utest_output_trans.hypo"
|
|
||||||
score_file_name = Path(tempfile.gettempdir()) / "utest_score.hypo"
|
|
||||||
|
|
||||||
testargs = [
|
|
||||||
"evaluate_wmt.py",
|
|
||||||
"patrickvonplaten/t5-tiny-random",
|
|
||||||
str(tmp_source),
|
|
||||||
str(output_file_name),
|
|
||||||
str(tmp_target),
|
|
||||||
str(score_file_name),
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
|
||||||
run_generate()
|
|
||||||
self.assertTrue(Path(output_file_name).exists())
|
|
||||||
self.assertTrue(Path(score_file_name).exists())
|
|
||||||
@@ -20,6 +20,7 @@ known_third_party =
|
|||||||
pandas
|
pandas
|
||||||
PIL
|
PIL
|
||||||
psutil
|
psutil
|
||||||
|
pytest
|
||||||
pytorch_lightning
|
pytorch_lightning
|
||||||
rouge_score
|
rouge_score
|
||||||
sacrebleu
|
sacrebleu
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
_all_mbart_models = ["facebook/mbart-large-en-ro"]
|
_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"]
|
||||||
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
||||||
|
|
||||||
|
|
||||||
@@ -105,6 +105,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
"vi_VN": 250024,
|
"vi_VN": 250024,
|
||||||
"zh_CN": 250025,
|
"zh_CN": 250025,
|
||||||
}
|
}
|
||||||
|
id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
|
||||||
cur_lang_code = lang_code_to_id["en_XX"]
|
cur_lang_code = lang_code_to_id["en_XX"]
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||||
@@ -115,6 +116,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return token_ids_0 + token_ids_1 + special_tokens
|
return token_ids_0 + token_ids_1 + special_tokens
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
|
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||||
|
if index in self.id_to_lang_code:
|
||||||
|
return self.id_to_lang_code[index]
|
||||||
|
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||||
|
|
||||||
|
def set_lang(self, lang: str) -> None:
|
||||||
|
"""Set the current language code in order to call batch_encode_plus properly."""
|
||||||
|
self.cur_lang_code = self.lang_code_to_id[lang]
|
||||||
|
|
||||||
def prepare_translation_batch(
|
def prepare_translation_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
|
|||||||
Reference in New Issue
Block a user