Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)
Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
@@ -7,6 +7,15 @@ For `bertabs` instructions, see `bertabs/README.md`.
|
|||||||
|
|
||||||
|
|
||||||
### Data
|
### Data
|
||||||
|
XSUM Data:
|
||||||
|
```bash
|
||||||
|
cd examples/seq2seq
|
||||||
|
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
||||||
|
tar -xzvf xsum.tar.gz
|
||||||
|
export XSUM_DIR=${PWD}/xsum
|
||||||
|
```
|
||||||
|
this should make a directory called cnn_dm/ with files like `test.source`.
|
||||||
|
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
||||||
|
|
||||||
CNN/DailyMail data
|
CNN/DailyMail data
|
||||||
```bash
|
```bash
|
||||||
@@ -17,18 +26,6 @@ tar -xzvf cnn_dm.tgz
|
|||||||
export CNN_DIR=${PWD}/cnn_dm
|
export CNN_DIR=${PWD}/cnn_dm
|
||||||
```
|
```
|
||||||
|
|
||||||
this should make a directory called cnn_dm/ with files like `test.source`.
|
|
||||||
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
|
||||||
|
|
||||||
XSUM Data:
|
|
||||||
```bash
|
|
||||||
cd examples/seq2seq
|
|
||||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
|
||||||
tar -xzvf xsum.tar.gz
|
|
||||||
export XSUM_DIR=${PWD}/xsum
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
WMT16 English-Romanian Translation Data:
|
WMT16 English-Romanian Translation Data:
|
||||||
```bash
|
```bash
|
||||||
cd examples/seq2seq
|
cd examples/seq2seq
|
||||||
@@ -40,7 +37,7 @@ export ENRO_DIR=${PWD}/wmt_en_ro
|
|||||||
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
|
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
|
||||||
The `.source` files are the input, the `.target` files are the desired output.
|
The `.source` files are the input, the `.target` files are the desired output.
|
||||||
|
|
||||||
|
|
||||||
### Tips and Tricks
|
### Tips and Tricks
|
||||||
|
|
||||||
General Tips:
|
General Tips:
|
||||||
@@ -64,6 +61,10 @@ Summarization Tips:
|
|||||||
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
||||||
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
||||||
|
|
||||||
|
**Update 2018-07-18**
|
||||||
|
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
|
||||||
|
A new dataset is needed to support multilingual tasks.
|
||||||
|
|
||||||
### Summarization Finetuning
|
### Summarization Finetuning
|
||||||
Run/modify `finetune.sh`
|
Run/modify `finetune.sh`
|
||||||
|
|
||||||
@@ -78,8 +79,6 @@ The following command should work on a 16GB GPU:
|
|||||||
--model_name_or_path facebook/bart-large
|
--model_name_or_path facebook/bart-large
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Translation Finetuning
|
### Translation Finetuning
|
||||||
|
|
||||||
First, follow the wmt_en_ro download instructions.
|
First, follow the wmt_en_ro download instructions.
|
||||||
@@ -124,23 +123,6 @@ from transformers import AutoModelForSeq2SeqLM
|
|||||||
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
||||||
```
|
```
|
||||||
|
|
||||||
#### XSUM Shared Task
|
|
||||||
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
|
|
||||||
|
|
||||||
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
|
||||||
```bash
|
|
||||||
WANDB_PROJECT='hf_xsum' ./finetune.sh \
|
|
||||||
--data_dir $XSUM_DIR \
|
|
||||||
--output_dir xsum_frozen_embs \
|
|
||||||
--model_name_or_path facebook/bart-large \
|
|
||||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
|
||||||
--num_train_epochs 6 \
|
|
||||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
|
||||||
--logger_name wandb
|
|
||||||
```
|
|
||||||
|
|
||||||
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
|
||||||
|
|
||||||
### Evaluation Commands
|
### Evaluation Commands
|
||||||
|
|
||||||
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
|
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
|
||||||
|
|||||||
@@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from .finetune import SummarizationModule
|
from .finetune import SummarizationModule
|
||||||
from .initialization_utils import init_student, copy_layers
|
|
||||||
from .utils import (
|
|
||||||
use_task_specific_params,
|
|
||||||
SummarizationDataset,
|
|
||||||
pickle_load,
|
|
||||||
freeze_params,
|
|
||||||
assert_all_frozen,
|
|
||||||
any_requires_grad,
|
|
||||||
)
|
|
||||||
from .finetune import main as ft_main
|
from .finetune import main as ft_main
|
||||||
|
from .initialization_utils import init_student, copy_layers
|
||||||
|
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from finetune import SummarizationModule
|
from finetune import SummarizationModule
|
||||||
from finetune import main as ft_main
|
from finetune import main as ft_main
|
||||||
from initialization_utils import init_student, copy_layers
|
from initialization_utils import init_student, copy_layers
|
||||||
from utils import (
|
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||||
use_task_specific_params,
|
|
||||||
SummarizationDataset,
|
|
||||||
pickle_load,
|
|
||||||
freeze_params,
|
|
||||||
assert_all_frozen,
|
|
||||||
any_requires_grad,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BartSummarizationDistiller(SummarizationModule):
|
class BartSummarizationDistiller(SummarizationModule):
|
||||||
@@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
if self.different_encoder:
|
if self.different_encoder:
|
||||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||||
|
|
||||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
|
||||||
n_obs = self.n_obs[type_path]
|
|
||||||
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# mask has False at padding_idx
|
# mask has False at padding_idx
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ try:
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
SummarizationDataset,
|
|
||||||
lmap,
|
lmap,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
pickle_save,
|
pickle_save,
|
||||||
@@ -32,12 +31,17 @@ try:
|
|||||||
get_git_info,
|
get_git_info,
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
calculate_bleu_score,
|
calculate_bleu_score,
|
||||||
|
Seq2SeqDataset,
|
||||||
|
MBartDataset,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import (
|
from utils import (
|
||||||
|
Seq2SeqDataset,
|
||||||
|
MBartDataset,
|
||||||
|
assert_all_frozen,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
SummarizationDataset,
|
|
||||||
lmap,
|
lmap,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
pickle_save,
|
pickle_save,
|
||||||
@@ -48,7 +52,6 @@ except ImportError:
|
|||||||
get_git_info,
|
get_git_info,
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
calculate_bleu_score,
|
calculate_bleu_score,
|
||||||
assert_all_frozen,
|
|
||||||
)
|
)
|
||||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||||
|
|
||||||
@@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||||
self.num_workers = hparams.num_workers
|
self.num_workers = hparams.num_workers
|
||||||
self.decoder_start_token_id = None
|
self.decoder_start_token_id = None
|
||||||
|
self.dataset_class = Seq2SeqDataset
|
||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
@@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
|
|
||||||
def _generative_step(self, batch: dict) -> dict:
|
def _generative_step(self, batch: dict) -> dict:
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
input_ids=source_ids,
|
input_ids=source_ids,
|
||||||
@@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer):
|
|||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
return self.validation_epoch_end(outputs, prefix="test")
|
return self.validation_epoch_end(outputs, prefix="test")
|
||||||
|
|
||||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||||
n_obs = self.n_obs[type_path]
|
n_obs = self.n_obs[type_path]
|
||||||
max_target_length = self.target_lens[type_path]
|
max_target_length = self.target_lens[type_path]
|
||||||
dataset = SummarizationDataset(
|
dataset = self.dataset_class(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
type_path=type_path,
|
type_path=type_path,
|
||||||
n_obs=n_obs,
|
n_obs=n_obs,
|
||||||
@@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule):
|
|||||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||||
|
if isinstance(self.tokenizer, MBartTokenizer):
|
||||||
|
self.dataset_class = MBartDataset
|
||||||
|
|
||||||
def calc_generative_metrics(self, preds, target) -> dict:
|
def calc_generative_metrics(self, preds, target) -> dict:
|
||||||
return calculate_bleu_score(preds, target)
|
return calculate_bleu_score(preds, target)
|
||||||
|
|||||||
@@ -9,16 +9,17 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from pytest import param
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer, MBartTokenizer
|
||||||
from transformers.testing_utils import require_multigpu
|
from transformers.testing_utils import require_multigpu
|
||||||
|
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
from .distillation import distill_main, evaluate_checkpoint
|
||||||
from .finetune import main
|
from .finetune import main
|
||||||
from .pack_dataset import pack_data_dir
|
from .pack_dataset import pack_data_dir
|
||||||
from .run_eval import generate_summaries_or_translations, run_generate
|
from .run_eval import generate_summaries_or_translations, run_generate
|
||||||
from .utils import SummarizationDataset, lmap, load_json
|
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG)
|
|||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
CHEAP_ARGS = {
|
CHEAP_ARGS = {
|
||||||
|
"label_smoothing_eps": 0.2,
|
||||||
"logger_name": "default",
|
"logger_name": "default",
|
||||||
"length_penalty": 0.5,
|
"length_penalty": 0.5,
|
||||||
"cache_dir": "",
|
"cache_dir": "",
|
||||||
@@ -80,11 +82,11 @@ CHEAP_ARGS = {
|
|||||||
|
|
||||||
|
|
||||||
def _dump_articles(path: Path, articles: list):
|
def _dump_articles(path: Path, articles: list):
|
||||||
with path.open("w") as f:
|
content = "\n".join(articles)
|
||||||
f.write("\n".join(articles))
|
Path(path).open("w").writelines(content)
|
||||||
|
|
||||||
|
|
||||||
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
|
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||||
BART_TINY = "sshleifer/bart-tiny-random"
|
BART_TINY = "sshleifer/bart-tiny-random"
|
||||||
@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
||||||
)
|
)
|
||||||
def test_finetune(model):
|
def test_finetune(model):
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
@@ -260,22 +262,50 @@ def test_pack_dataset():
|
|||||||
assert orig_paths == new_paths
|
assert orig_paths == new_paths
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def test_mbart_dataset_truncation():
|
||||||
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
|
||||||
)
|
tmp_dir = make_test_data_dir()
|
||||||
def test_dataset(tok):
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
|
trunc = 4
|
||||||
|
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
||||||
|
train_dataset = MBartDataset(
|
||||||
|
tokenizer,
|
||||||
|
data_dir=tmp_dir,
|
||||||
|
type_path="train",
|
||||||
|
max_source_length=trunc,
|
||||||
|
max_target_length=1000, # ignored
|
||||||
|
src_lang=src_lang,
|
||||||
|
tgt_lang=tgt_lang,
|
||||||
|
)
|
||||||
|
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||||
|
for batch in dataloader:
|
||||||
|
assert isinstance(batch, dict)
|
||||||
|
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||||
|
# show that articles were trimmed.
|
||||||
|
assert batch["input_ids"].shape[1] == trunc
|
||||||
|
# show that targets are the same len
|
||||||
|
assert batch["decoder_input_ids"].shape[1] == trunc
|
||||||
|
# check language codes in correct place
|
||||||
|
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||||
|
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||||
|
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||||
|
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
||||||
|
|
||||||
|
assert max_len_target > trunc # Truncated
|
||||||
|
assert max_len_source > trunc
|
||||||
|
break # No need to test every batch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
|
||||||
|
def test_summarization_dataset_truncation(tok):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir()
|
||||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
trunc_target = 4
|
trunc_target = 4
|
||||||
train_dataset = SummarizationDataset(
|
train_dataset = Seq2SeqDataset(
|
||||||
tokenizer,
|
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||||
data_dir=tmp_dir,
|
|
||||||
type_path="train",
|
|
||||||
max_source_length=20,
|
|
||||||
max_target_length=trunc_target,
|
|
||||||
tgt_lang="ro_RO",
|
|
||||||
)
|
)
|
||||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
@@ -286,3 +316,4 @@ def test_dataset(tok):
|
|||||||
# show that targets were truncated
|
# show that targets were truncated
|
||||||
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
||||||
assert max_len_target > trunc_target # Truncated
|
assert max_len_target > trunc_target # Truncated
|
||||||
|
break # No need to test every batch
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
|
import linecache
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import warnings
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Iterable, List
|
from typing import Callable, Dict, Iterable, List
|
||||||
@@ -13,50 +15,20 @@ from rouge_score import rouge_scorer, scoring
|
|||||||
from sacrebleu import corpus_bleu
|
from sacrebleu import corpus_bleu
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset, Sampler
|
from torch.utils.data import Dataset, Sampler
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from transformers import BartTokenizer
|
from transformers import BartTokenizer
|
||||||
|
|
||||||
|
|
||||||
def encode_file(
|
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||||
tokenizer,
|
|
||||||
data_path,
|
|
||||||
max_length,
|
|
||||||
pad_to_max_length=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
overwrite_cache=False,
|
|
||||||
prefix="",
|
|
||||||
tok_name="",
|
|
||||||
):
|
|
||||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
||||||
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
|
return tokenizer(
|
||||||
if not overwrite_cache and cache_path.exists():
|
[line],
|
||||||
try:
|
max_length=max_length,
|
||||||
examples = torch.load(cache_path)
|
padding="max_length" if pad_to_max_length else None,
|
||||||
assert isinstance(examples, list)
|
truncation=True,
|
||||||
return examples
|
return_tensors=return_tensors,
|
||||||
|
**extra_kw,
|
||||||
except Exception:
|
)
|
||||||
print(f"failed to load from {cache_path}, retokenizing {data_path}")
|
|
||||||
data_path = Path(data_path)
|
|
||||||
|
|
||||||
lns = lmap(str.strip, data_path.open().readlines())
|
|
||||||
lns = [prefix + text for text in lns]
|
|
||||||
assert lns, f"found empty file at {data_path}"
|
|
||||||
examples = []
|
|
||||||
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
|
|
||||||
tokenized = tokenizer(
|
|
||||||
[text],
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length" if pad_to_max_length else None,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
**extra_kw,
|
|
||||||
)
|
|
||||||
assert tokenized.input_ids.shape[1] == max_length
|
|
||||||
examples.append(tokenized)
|
|
||||||
torch.save(lmap(dict, examples), cache_path.open("wb"))
|
|
||||||
return examples
|
|
||||||
|
|
||||||
|
|
||||||
def lmap(f: Callable, x: Iterable) -> List:
|
def lmap(f: Callable, x: Iterable) -> List:
|
||||||
@@ -80,73 +52,111 @@ def trim_batch(
|
|||||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||||
|
|
||||||
|
|
||||||
class SummarizationDataset(Dataset):
|
class Seq2SeqDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
data_dir,
|
data_dir,
|
||||||
|
max_source_length,
|
||||||
|
max_target_length,
|
||||||
type_path="train",
|
type_path="train",
|
||||||
max_source_length=1024,
|
|
||||||
max_target_length=56,
|
|
||||||
n_obs=None,
|
n_obs=None,
|
||||||
overwrite_cache=False,
|
|
||||||
prefix="",
|
|
||||||
src_lang=None,
|
src_lang=None,
|
||||||
tgt_lang=None,
|
tgt_lang=None,
|
||||||
|
prefix="",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# FIXME: the rstrip logic strips all the chars, it seems.
|
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||||
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
|
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||||
if hasattr(tokenizer, "set_lang") and src_lang is not None:
|
self.src_lens = self.get_char_lens(self.src_file)
|
||||||
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
|
self.max_source_length = max_source_length
|
||||||
self.source = encode_file(
|
self.max_target_length = max_target_length
|
||||||
tokenizer,
|
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||||
os.path.join(data_dir, type_path + ".source"),
|
self.tokenizer = tokenizer
|
||||||
max_source_length,
|
self.prefix = prefix
|
||||||
overwrite_cache=overwrite_cache,
|
|
||||||
prefix=prefix,
|
|
||||||
tok_name=tok_name,
|
|
||||||
)
|
|
||||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
|
||||||
if hasattr(tokenizer, "set_lang"):
|
|
||||||
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
|
|
||||||
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
|
|
||||||
self.target = encode_file(
|
|
||||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
|
||||||
)
|
|
||||||
if n_obs is not None:
|
if n_obs is not None:
|
||||||
self.source = self.source[:n_obs]
|
self.src_lens = self.src_lens[:n_obs]
|
||||||
self.target = self.target[:n_obs]
|
self.pad_token_id = self.tokenizer.pad_token_id
|
||||||
self.pad_token_id = tokenizer.pad_token_id
|
self.src_lang = src_lang
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.source)
|
return len(self.src_lens)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||||
source_ids = self.source[index]["input_ids"].squeeze()
|
index = index + 1 # linecache starts at 1
|
||||||
target_ids = self.target[index]["input_ids"].squeeze()
|
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||||
src_mask = self.source[index]["attention_mask"].squeeze()
|
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||||
return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}
|
assert source_line, f"empty source line for index {index}"
|
||||||
|
assert tgt_line, f"empty tgt line for index {index}"
|
||||||
|
source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
|
||||||
|
target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
|
||||||
|
|
||||||
|
source_ids = source_inputs["input_ids"].squeeze()
|
||||||
|
target_ids = target_inputs["input_ids"].squeeze()
|
||||||
|
src_mask = source_inputs["attention_mask"].squeeze()
|
||||||
|
return {
|
||||||
|
"input_ids": source_ids,
|
||||||
|
"attention_mask": src_mask,
|
||||||
|
"decoder_input_ids": target_ids,
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trim_seq2seq_batch(batch, pad_token_id):
|
def get_char_lens(data_file):
|
||||||
|
return [len(x) for x in Path(data_file).open().readlines()]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
|
||||||
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
|
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
|
||||||
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
|
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
|
||||||
return source_ids, source_mask, y
|
return source_ids, source_mask, y
|
||||||
|
|
||||||
def collate_fn(self, batch) -> dict:
|
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||||
pad_token_id = self.pad_token_id
|
pad_token_id = self.pad_token_id
|
||||||
y = trim_batch(target_ids, pad_token_id)
|
y = trim_batch(target_ids, pad_token_id)
|
||||||
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||||
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
|
batch = {
|
||||||
|
"input_ids": source_ids,
|
||||||
|
"attention_mask": source_mask,
|
||||||
|
"decoder_input_ids": y,
|
||||||
|
}
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def make_sortish_sampler(self, batch_size):
|
def make_sortish_sampler(self, batch_size):
|
||||||
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
|
return SortishSampler(self.src_lens, batch_size)
|
||||||
return SortishSampler(lens, batch_size)
|
|
||||||
|
|
||||||
|
class MBartDataset(Seq2SeqDataset):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
if self.max_source_length != self.max_target_length:
|
||||||
|
warnings.warn(
|
||||||
|
f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, index) -> Dict[str, str]:
|
||||||
|
index = index + 1 # linecache starts at 1
|
||||||
|
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||||
|
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||||
|
assert source_line, f"empty source line for index {index}"
|
||||||
|
assert tgt_line, f"empty tgt line for index {index}"
|
||||||
|
return {
|
||||||
|
"tgt_texts": source_line,
|
||||||
|
"src_texts": tgt_line,
|
||||||
|
}
|
||||||
|
|
||||||
|
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
|
batch_encoding = self.tokenizer.prepare_translation_batch(
|
||||||
|
[x["src_texts"] for x in batch],
|
||||||
|
src_lang=self.src_lang,
|
||||||
|
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||||
|
tgt_lang=self.tgt_lang,
|
||||||
|
max_length=self.max_source_length,
|
||||||
|
)
|
||||||
|
return batch_encoding.data
|
||||||
|
|
||||||
|
|
||||||
class SortishSampler(Sampler):
|
class SortishSampler(Sampler):
|
||||||
|
|||||||
@@ -118,12 +118,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||||
self.reset_special_tokens()
|
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||||
|
|
||||||
def reset_special_tokens(self) -> None:
|
|
||||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
|
|
||||||
self.prefix_tokens = []
|
|
||||||
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(
|
def build_inputs_with_special_tokens(
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
@@ -183,12 +178,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||||
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||||
|
|
||||||
def set_lang(self, lang: str) -> None:
|
|
||||||
"""Set the current language code in order to call tokenizer properly."""
|
|
||||||
self.cur_lang_code = self.lang_code_to_id[lang]
|
|
||||||
self.prefix_tokens = [self.cur_lang_code]
|
|
||||||
self.suffix_tokens = [self.eos_token_id]
|
|
||||||
|
|
||||||
def prepare_translation_batch(
|
def prepare_translation_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
@@ -215,7 +204,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
"""
|
"""
|
||||||
if max_length is None:
|
if max_length is None:
|
||||||
max_length = self.max_len
|
max_length = self.max_len
|
||||||
self.cur_lang_code = self.lang_code_to_id[src_lang]
|
self.set_src_lang_special_tokens(src_lang)
|
||||||
model_inputs: BatchEncoding = self(
|
model_inputs: BatchEncoding = self(
|
||||||
src_texts,
|
src_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
@@ -227,7 +216,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
)
|
)
|
||||||
if tgt_texts is None:
|
if tgt_texts is None:
|
||||||
return model_inputs
|
return model_inputs
|
||||||
self.set_lang(tgt_lang)
|
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||||
decoder_inputs: BatchEncoding = self(
|
decoder_inputs: BatchEncoding = self(
|
||||||
tgt_texts,
|
tgt_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
@@ -239,6 +228,18 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
)
|
)
|
||||||
for k, v in decoder_inputs.items():
|
for k, v in decoder_inputs.items():
|
||||||
model_inputs[f"decoder_{k}"] = v
|
model_inputs[f"decoder_{k}"] = v
|
||||||
self.cur_lang_code = self.lang_code_to_id[src_lang]
|
|
||||||
self.reset_special_tokens() # sets to src_lang
|
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||||
|
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
|
||||||
|
self.cur_lang_code = self.lang_code_to_id[src_lang]
|
||||||
|
self.prefix_tokens = []
|
||||||
|
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||||
|
|
||||||
|
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
||||||
|
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
|
||||||
|
self.cur_lang_code = self.lang_code_to_id[lang]
|
||||||
|
self.prefix_tokens = [self.cur_lang_code]
|
||||||
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
|
|||||||
Reference in New Issue
Block a user