Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)
Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
@@ -7,6 +7,15 @@ For `bertabs` instructions, see `bertabs/README.md`.
|
||||
|
||||
|
||||
### Data
|
||||
XSUM Data:
|
||||
```bash
|
||||
cd examples/seq2seq
|
||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
||||
tar -xzvf xsum.tar.gz
|
||||
export XSUM_DIR=${PWD}/xsum
|
||||
```
|
||||
this should make a directory called cnn_dm/ with files like `test.source`.
|
||||
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
||||
|
||||
CNN/DailyMail data
|
||||
```bash
|
||||
@@ -17,18 +26,6 @@ tar -xzvf cnn_dm.tgz
|
||||
export CNN_DIR=${PWD}/cnn_dm
|
||||
```
|
||||
|
||||
this should make a directory called cnn_dm/ with files like `test.source`.
|
||||
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
||||
|
||||
XSUM Data:
|
||||
```bash
|
||||
cd examples/seq2seq
|
||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
||||
tar -xzvf xsum.tar.gz
|
||||
export XSUM_DIR=${PWD}/xsum
|
||||
```
|
||||
|
||||
|
||||
WMT16 English-Romanian Translation Data:
|
||||
```bash
|
||||
cd examples/seq2seq
|
||||
@@ -40,7 +37,7 @@ export ENRO_DIR=${PWD}/wmt_en_ro
|
||||
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
|
||||
The `.source` files are the input, the `.target` files are the desired output.
|
||||
|
||||
|
||||
|
||||
### Tips and Tricks
|
||||
|
||||
General Tips:
|
||||
@@ -64,6 +61,10 @@ Summarization Tips:
|
||||
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
||||
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
||||
|
||||
**Update 2018-07-18**
|
||||
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
|
||||
A new dataset is needed to support multilingual tasks.
|
||||
|
||||
### Summarization Finetuning
|
||||
Run/modify `finetune.sh`
|
||||
|
||||
@@ -78,8 +79,6 @@ The following command should work on a 16GB GPU:
|
||||
--model_name_or_path facebook/bart-large
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Translation Finetuning
|
||||
|
||||
First, follow the wmt_en_ro download instructions.
|
||||
@@ -124,23 +123,6 @@ from transformers import AutoModelForSeq2SeqLM
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
||||
```
|
||||
|
||||
#### XSUM Shared Task
|
||||
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
|
||||
|
||||
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
||||
```bash
|
||||
WANDB_PROJECT='hf_xsum' ./finetune.sh \
|
||||
--data_dir $XSUM_DIR \
|
||||
--output_dir xsum_frozen_embs \
|
||||
--model_name_or_path facebook/bart-large \
|
||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||
--num_train_epochs 6 \
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||
--logger_name wandb
|
||||
```
|
||||
|
||||
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
||||
|
||||
### Evaluation Commands
|
||||
|
||||
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
|
||||
|
||||
@@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
from .finetune import main as ft_main
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import init_student, copy_layers
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||
|
||||
|
||||
class BartSummarizationDistiller(SummarizationModule):
|
||||
@@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
if self.different_encoder:
|
||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
|
||||
return dataset
|
||||
|
||||
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
||||
if mask is not None:
|
||||
# mask has False at padding_idx
|
||||
|
||||
@@ -21,7 +21,6 @@ try:
|
||||
from .utils import (
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
@@ -32,12 +31,17 @@ try:
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
Seq2SeqDataset,
|
||||
MBartDataset,
|
||||
)
|
||||
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
except ImportError:
|
||||
from utils import (
|
||||
Seq2SeqDataset,
|
||||
MBartDataset,
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
@@ -48,7 +52,6 @@ except ImportError:
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
assert_all_frozen,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
|
||||
@@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer):
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.decoder_start_token_id = None
|
||||
self.dataset_class = Seq2SeqDataset
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
@@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
t0 = time.time()
|
||||
generated_ids = self.model.generate(
|
||||
input_ids=source_ids,
|
||||
@@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer):
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = SummarizationDataset(
|
||||
dataset = self.dataset_class(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
@@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule):
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
if isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.dataset_class = MBartDataset
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
|
||||
@@ -9,16 +9,17 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest import param
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, MBartTokenizer
|
||||
from transformers.testing_utils import require_multigpu
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .pack_dataset import pack_data_dir
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .utils import SummarizationDataset, lmap, load_json
|
||||
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"label_smoothing_eps": 0.2,
|
||||
"logger_name": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
@@ -80,11 +82,11 @@ CHEAP_ARGS = {
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
with path.open("w") as f:
|
||||
f.write("\n".join(articles))
|
||||
content = "\n".join(articles)
|
||||
Path(path).open("w").writelines(content)
|
||||
|
||||
|
||||
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
@@ -260,22 +262,50 @@ def test_pack_dataset():
|
||||
assert orig_paths == new_paths
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||
)
|
||||
def test_dataset(tok):
|
||||
def test_mbart_dataset_truncation():
|
||||
tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc = 4
|
||||
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
||||
train_dataset = MBartDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=trunc,
|
||||
max_target_length=1000, # ignored
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
assert isinstance(batch, dict)
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == trunc
|
||||
# show that targets are the same len
|
||||
assert batch["decoder_input_ids"].shape[1] == trunc
|
||||
# check language codes in correct place
|
||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
||||
|
||||
assert max_len_target > trunc # Truncated
|
||||
assert max_len_source > trunc
|
||||
break # No need to test every batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
|
||||
def test_summarization_dataset_truncation(tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = SummarizationDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=20,
|
||||
max_target_length=trunc_target,
|
||||
tgt_lang="ro_RO",
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
@@ -286,3 +316,4 @@ def test_dataset(tok):
|
||||
# show that targets were truncated
|
||||
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
||||
assert max_len_target > trunc_target # Truncated
|
||||
break # No need to test every batch
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import itertools
|
||||
import json
|
||||
import linecache
|
||||
import os
|
||||
import pickle
|
||||
import warnings
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
@@ -13,50 +15,20 @@ from rouge_score import rouge_scorer, scoring
|
||||
from sacrebleu import corpus_bleu
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartTokenizer
|
||||
|
||||
|
||||
def encode_file(
|
||||
tokenizer,
|
||||
data_path,
|
||||
max_length,
|
||||
pad_to_max_length=True,
|
||||
return_tensors="pt",
|
||||
overwrite_cache=False,
|
||||
prefix="",
|
||||
tok_name="",
|
||||
):
|
||||
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
||||
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
|
||||
if not overwrite_cache and cache_path.exists():
|
||||
try:
|
||||
examples = torch.load(cache_path)
|
||||
assert isinstance(examples, list)
|
||||
return examples
|
||||
|
||||
except Exception:
|
||||
print(f"failed to load from {cache_path}, retokenizing {data_path}")
|
||||
data_path = Path(data_path)
|
||||
|
||||
lns = lmap(str.strip, data_path.open().readlines())
|
||||
lns = [prefix + text for text in lns]
|
||||
assert lns, f"found empty file at {data_path}"
|
||||
examples = []
|
||||
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
|
||||
tokenized = tokenizer(
|
||||
[text],
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
**extra_kw,
|
||||
)
|
||||
assert tokenized.input_ids.shape[1] == max_length
|
||||
examples.append(tokenized)
|
||||
torch.save(lmap(dict, examples), cache_path.open("wb"))
|
||||
return examples
|
||||
return tokenizer(
|
||||
[line],
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
**extra_kw,
|
||||
)
|
||||
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
@@ -80,73 +52,111 @@ def trim_batch(
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
|
||||
class Seq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
type_path="train",
|
||||
max_source_length=1024,
|
||||
max_target_length=56,
|
||||
n_obs=None,
|
||||
overwrite_cache=False,
|
||||
prefix="",
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
prefix="",
|
||||
):
|
||||
super().__init__()
|
||||
# FIXME: the rstrip logic strips all the chars, it seems.
|
||||
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
|
||||
if hasattr(tokenizer, "set_lang") and src_lang is not None:
|
||||
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
|
||||
self.source = encode_file(
|
||||
tokenizer,
|
||||
os.path.join(data_dir, type_path + ".source"),
|
||||
max_source_length,
|
||||
overwrite_cache=overwrite_cache,
|
||||
prefix=prefix,
|
||||
tok_name=tok_name,
|
||||
)
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
if hasattr(tokenizer, "set_lang"):
|
||||
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
|
||||
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
|
||||
self.target = encode_file(
|
||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||
)
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.max_source_length = max_source_length
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix = prefix
|
||||
if n_obs is not None:
|
||||
self.source = self.source[:n_obs]
|
||||
self.target = self.target[:n_obs]
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
self.src_lens = self.src_lens[:n_obs]
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
|
||||
def __len__(self):
|
||||
return len(self.source)
|
||||
return len(self.src_lens)
|
||||
|
||||
def __getitem__(self, index):
|
||||
source_ids = self.source[index]["input_ids"].squeeze()
|
||||
target_ids = self.target[index]["input_ids"].squeeze()
|
||||
src_mask = self.source[index]["attention_mask"].squeeze()
|
||||
return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
|
||||
target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
|
||||
|
||||
source_ids = source_inputs["input_ids"].squeeze()
|
||||
target_ids = target_inputs["input_ids"].squeeze()
|
||||
src_mask = source_inputs["attention_mask"].squeeze()
|
||||
return {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": src_mask,
|
||||
"decoder_input_ids": target_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def trim_seq2seq_batch(batch, pad_token_id):
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
@staticmethod
|
||||
def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
|
||||
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
|
||||
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
|
||||
return source_ids, source_mask, y
|
||||
|
||||
def collate_fn(self, batch) -> dict:
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||
pad_token_id = self.pad_token_id
|
||||
y = trim_batch(target_ids, pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
|
||||
batch = {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": source_mask,
|
||||
"decoder_input_ids": y,
|
||||
}
|
||||
return batch
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
|
||||
return SortishSampler(lens, batch_size)
|
||||
return SortishSampler(self.src_lens, batch_size)
|
||||
|
||||
|
||||
class MBartDataset(Seq2SeqDataset):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.max_source_length != self.max_target_length:
|
||||
warnings.warn(
|
||||
f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
|
||||
)
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, str]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
return {
|
||||
"tgt_texts": source_line,
|
||||
"src_texts": tgt_line,
|
||||
}
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_translation_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.src_lang,
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
tgt_lang=self.tgt_lang,
|
||||
max_length=self.max_source_length,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
class SortishSampler(Sampler):
|
||||
|
||||
Reference in New Issue
Block a user