[s2s] distributed eval in one command (#7124)
This commit is contained in:
@@ -1,46 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import fire
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
|
|
||||||
except ImportError:
|
|
||||||
from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
|
|
||||||
|
|
||||||
|
|
||||||
def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Merge the per-rank ``rank*.json`` files found in ``result_dir`` and score them.

    Each matching file is read with ``load_json`` and its records are concatenated.
    Every record must provide a ``"pred"`` and a ``"label"`` entry. The selected metric
    is computed over all records, written to ``<save_dir>/metrics.json`` and printed.
    Unless ``just_metrics`` is set, the labels and predictions are also written to
    ``<save_prefix>.target`` and ``<save_prefix>.pred_target``, plus
    ``<save_prefix>.source`` when the first record carries a ``"source"`` entry.

    Args:
        result_dir: directory searched (non-recursively) for ``rank*.json`` files.
        save_dir: output directory; it is created if missing (its parent must exist).
            Defaults to ``result_dir`` when omitted.
        save_prefix: stem of the text outputs; ``"generated"`` is used when ``None``.
        calc_bleu: score with ``calculate_bleu`` instead of ``calculate_rouge``.
        just_metrics: return right after the metrics have been saved and printed.
    """
    src_dir = Path(result_dir)
    # The signature advertises save_dir as optional, but Path(None) raises TypeError,
    # so fall back to writing next to the partial results.
    save_dir = src_dir if save_dir is None else Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    # glob() yields paths in a filesystem-dependent order; sort them (lexicographically)
    # so the line order of the generated files is reproducible between runs.
    paths_to_combine = sorted(src_dir.glob("rank*.json"))
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]

    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
    print(metrics)
    if just_metrics:
        return

    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")

    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    # Guard against an empty record list before peeking at the first record.
    if records and "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
fire.Fire(combine_partial_results)
|
|
||||||
@@ -12,7 +12,7 @@ Note: You need to have your test_generations.txt before you start this process.
|
|||||||
cd $HOME
|
cd $HOME
|
||||||
git clone git@github.com:moses-smt/mosesdecoder.git
|
git clone git@github.com:moses-smt/mosesdecoder.git
|
||||||
cd mosesdecoder
|
cd mosesdecoder
|
||||||
git@github.com:rsennrich/wmt16-scripts.git
|
git clone git@github.com:rsennrich/wmt16-scripts.git
|
||||||
```
|
```
|
||||||
|
|
||||||
(2) define a function for post processing.
|
(2) define a function for post processing.
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from json import JSONDecodeError
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -13,12 +16,29 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
from .utils import (
|
||||||
|
Seq2SeqDataset,
|
||||||
|
calculate_bleu,
|
||||||
|
calculate_rouge,
|
||||||
|
lmap,
|
||||||
|
load_json,
|
||||||
|
parse_numeric_cl_kwargs,
|
||||||
|
save_json,
|
||||||
|
use_task_specific_params,
|
||||||
|
write_txt_file,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
|
from utils import (
|
||||||
|
Seq2SeqDataset,
|
||||||
|
calculate_bleu,
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
calculate_rouge,
|
||||||
|
lmap,
|
||||||
|
load_json,
|
||||||
|
parse_numeric_cl_kwargs,
|
||||||
|
save_json,
|
||||||
|
use_task_specific_params,
|
||||||
|
write_txt_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def eval_data_dir(
|
def eval_data_dir(
|
||||||
@@ -30,7 +50,6 @@ def eval_data_dir(
|
|||||||
type_path="val",
|
type_path="val",
|
||||||
n_obs=None,
|
n_obs=None,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
save_source=False,
|
|
||||||
num_beams: int = 4,
|
num_beams: int = 4,
|
||||||
task="summarization",
|
task="summarization",
|
||||||
local_rank=None,
|
local_rank=None,
|
||||||
@@ -62,7 +81,7 @@ def eval_data_dir(
|
|||||||
n_obs=n_obs,
|
n_obs=n_obs,
|
||||||
prefix=model.config.prefix,
|
prefix=model.config.prefix,
|
||||||
)
|
)
|
||||||
sampler = ds.make_sortish_sampler(bs, distributed=True)
|
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
|
||||||
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
|
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
|
||||||
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
|
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
|
||||||
results = []
|
results = []
|
||||||
@@ -75,23 +94,19 @@ def eval_data_dir(
|
|||||||
)
|
)
|
||||||
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
|
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
|
||||||
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
|
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
|
||||||
if save_source:
|
ids = batch["ids"]
|
||||||
docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
|
|
||||||
for i in range(len(labels)):
|
for i in range(len(labels)):
|
||||||
label, pred = labels[i], preds[i]
|
label, pred = labels[i], preds[i]
|
||||||
if save_source:
|
results.append(dict(pred=pred, label=label, id=ids[i].item()))
|
||||||
results.append(dict(pred=pred, label=label, source=docs[i]))
|
|
||||||
else:
|
|
||||||
results.append(dict(pred=pred, label=label))
|
|
||||||
save_json(results, save_path)
|
save_json(results, save_path)
|
||||||
return results
|
return results, sampler.num_replicas
|
||||||
|
|
||||||
|
|
||||||
def run_generate():
|
def run_generate():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
|
epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
|
||||||
)
|
)
|
||||||
parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
|
parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name",
|
"--model_name",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -113,17 +128,31 @@ def run_generate():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
|
"--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sync_timeout",
|
||||||
|
type=int,
|
||||||
|
default=600,
|
||||||
|
required=False,
|
||||||
|
help="How long should master process wait for other processes to finish.",
|
||||||
|
)
|
||||||
parser.add_argument("--fp16", action="store_true")
|
parser.add_argument("--fp16", action="store_true")
|
||||||
parser.add_argument("--save_source", action="store_true")
|
parser.add_argument("--debug", action="store_true")
|
||||||
|
start_time = time.time()
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
generate_kwargs = parse_numeric_cl_kwargs(rest)
|
generate_kwargs = parse_numeric_cl_kwargs(rest)
|
||||||
if generate_kwargs:
|
if generate_kwargs:
|
||||||
print(f"parsed the following generate kwargs: {generate_kwargs}")
|
print(f"parsed the following generate kwargs: {generate_kwargs}")
|
||||||
|
json_save_dir = Path(args.save_dir + "_tmp")
|
||||||
|
Path(json_save_dir).mkdir(exist_ok=True) # this handles locking.
|
||||||
|
intermediate_files = list(json_save_dir.glob("rank_*.json"))
|
||||||
|
if intermediate_files:
|
||||||
|
raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
|
||||||
|
# In theory, a node could finish and save before another node hits this. If this happens, we can address later.
|
||||||
|
|
||||||
Path(args.save_dir).mkdir(exist_ok=True)
|
Path(args.save_dir).mkdir(exist_ok=True)
|
||||||
eval_data_dir(
|
results, num_replicas = eval_data_dir(
|
||||||
args.input_path,
|
args.data_dir,
|
||||||
args.save_dir,
|
json_save_dir,
|
||||||
args.model_name,
|
args.model_name,
|
||||||
type_path=args.type_path,
|
type_path=args.type_path,
|
||||||
batch_size=args.bs,
|
batch_size=args.bs,
|
||||||
@@ -131,11 +160,64 @@ def run_generate():
|
|||||||
task=args.task,
|
task=args.task,
|
||||||
local_rank=args.local_rank,
|
local_rank=args.local_rank,
|
||||||
n_obs=args.n_obs,
|
n_obs=args.n_obs,
|
||||||
save_source=args.save_source,
|
|
||||||
max_source_length=args.max_source_length,
|
max_source_length=args.max_source_length,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.local_rank <= 0:
|
||||||
|
save_dir = Path(args.save_dir)
|
||||||
|
save_dir.mkdir(exist_ok=True)
|
||||||
|
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
|
||||||
|
preds, labels = combine_partial_results(partial_results)
|
||||||
|
# Calculate metrics, save metrics, and save _generations.txt
|
||||||
|
calc_bleu = "translation" in args.task
|
||||||
|
score_fn = calculate_bleu if calc_bleu else calculate_rouge
|
||||||
|
metric_name = "bleu" if calc_bleu else "rouge"
|
||||||
|
metrics: Dict = score_fn(preds, labels)
|
||||||
|
metrics["n_obs"] = len(preds)
|
||||||
|
runtime = time.time() - start_time
|
||||||
|
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
|
||||||
|
# TODO(@stas00): add whatever metadata to metrics
|
||||||
|
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
|
||||||
|
save_json(metrics, metrics_save_path)
|
||||||
|
print(metrics)
|
||||||
|
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
|
||||||
|
if args.debug:
|
||||||
|
write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target"))
|
||||||
|
else:
|
||||||
|
shutil.rmtree(json_save_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def combine_partial_results(partial_results) -> Tuple[List, List]:
    """Flatten the per-rank record lists, order the records by ``"id"`` and split them.

    Args:
        partial_results: iterable of iterables of records; each record is a mapping
            with ``"id"``, ``"pred"`` and ``"label"`` entries.

    Returns:
        A ``(preds, labels)`` pair of lists, both ordered by ascending record id.
    """
    merged = [record for chunk in partial_results for record in chunk]
    # list.sort is stable, just like sorted(): records sharing an id keep their input order.
    merged.sort(key=lambda record: record["id"])
    return [record["pred"] for record in merged], [record["label"] for record in merged]
|
||||||
|
|
||||||
|
|
||||||
|
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
    """Poll ``save_dir`` until the ``rank_*.json`` files of all replicas can be loaded.

    Args:
        num_replicas: minimum number of ``rank_*.json`` files that must exist before
            loading is attempted.
        save_dir: directory object exposing ``glob`` (the caller in this file passes a
            ``pathlib.Path``).
        timeout: maximum number of seconds to keep polling.

    Returns:
        The result of applying ``load_json`` to every matching file via ``lmap``
        (one entry per file, in glob order).

    Raises:
        TimeoutError: if the files are not all present and parseable within ``timeout``
            seconds.
    """
    # Pause between polls: without it the loop spins a CPU core and hammers the
    # filesystem with glob() calls for as long as the other processes are still running.
    poll_interval = 0.1  # seconds
    start_wait = time.time()
    logger.info("waiting for all nodes to finish")
    while (time.time() - start_wait) < timeout:
        json_files = list(save_dir.glob("rank_*.json"))
        if len(json_files) >= num_replicas:
            try:
                # A file that is still being written fails to parse; retry on the next poll.
                return lmap(load_json, json_files)
            except JSONDecodeError:
                pass
        time.sleep(poll_interval)
    raise TimeoutError("Rank 0 gave up on waiting for other processes")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Usage for MT:
|
# Usage for MT:
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from torch import nn
|
|||||||
from torch.utils.data import Dataset, Sampler
|
from torch.utils.data import Dataset, Sampler
|
||||||
|
|
||||||
from transformers import BartTokenizer
|
from transformers import BartTokenizer
|
||||||
|
from transformers.file_utils import cached_property
|
||||||
|
|
||||||
|
|
||||||
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
||||||
@@ -114,9 +115,9 @@ class AbstractSeq2SeqDataset(Dataset):
|
|||||||
def get_char_lens(data_file):
|
def get_char_lens(data_file):
|
||||||
return [len(x) for x in Path(data_file).open().readlines()]
|
return [len(x) for x in Path(data_file).open().readlines()]
|
||||||
|
|
||||||
def make_sortish_sampler(self, batch_size, distributed=False):
|
def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
|
||||||
if distributed:
|
if distributed:
|
||||||
return DistributedSortishSampler(self, batch_size)
|
return DistributedSortishSampler(self, batch_size, **kwargs)
|
||||||
else:
|
else:
|
||||||
return SortishSampler(self.src_lens, batch_size)
|
return SortishSampler(self.src_lens, batch_size)
|
||||||
|
|
||||||
@@ -171,14 +172,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
|||||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||||
assert source_line, f"empty source line for index {index}"
|
assert source_line, f"empty source line for index {index}"
|
||||||
assert tgt_line, f"empty tgt line for index {index}"
|
assert tgt_line, f"empty tgt line for index {index}"
|
||||||
return {
|
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
|
||||||
"tgt_texts": tgt_line,
|
|
||||||
"src_texts": source_line,
|
|
||||||
}
|
|
||||||
|
|
||||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
"""Call prepare_seq2seq_batch."""
|
"""Call prepare_seq2seq_batch."""
|
||||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
||||||
[x["src_texts"] for x in batch],
|
[x["src_texts"] for x in batch],
|
||||||
src_lang=self.src_lang,
|
src_lang=self.src_lang,
|
||||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||||
@@ -187,8 +185,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
|||||||
max_target_length=self.max_target_length,
|
max_target_length=self.max_target_length,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
add_prefix_space=self.add_prefix_space,
|
add_prefix_space=self.add_prefix_space,
|
||||||
)
|
).data
|
||||||
return batch_encoding.data
|
batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
|
||||||
|
return batch_encoding
|
||||||
|
|
||||||
|
|
||||||
class SortishSampler(Sampler):
|
class SortishSampler(Sampler):
|
||||||
@@ -226,7 +225,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
|
|||||||
class DistributedSortishSampler(Sampler):
|
class DistributedSortishSampler(Sampler):
|
||||||
"""Copied from torch DistributedSampler"""
|
"""Copied from torch DistributedSampler"""
|
||||||
|
|
||||||
def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
|
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
|
||||||
if num_replicas is None:
|
if num_replicas is None:
|
||||||
if not dist.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
@@ -239,22 +238,27 @@ class DistributedSortishSampler(Sampler):
|
|||||||
self.num_replicas = num_replicas
|
self.num_replicas = num_replicas
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
if add_extra_examples:
|
||||||
self.total_size = self.num_samples * self.num_replicas
|
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||||
|
self.total_size = self.num_samples * self.num_replicas
|
||||||
|
else:
|
||||||
|
self.total_size = len(dataset)
|
||||||
|
self.num_samples = len(self.available_indices)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
self.add_extra_examples = add_extra_examples
|
||||||
|
|
||||||
def __iter__(self) -> Iterable:
|
def __iter__(self) -> Iterable:
|
||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
g.manual_seed(self.epoch)
|
g.manual_seed(self.epoch)
|
||||||
available_indices = self.get_indices_for_rank() # indices[self.rank: self.total_size: self.num_replicas]
|
|
||||||
|
|
||||||
sortish_data = [self.dataset.src_lens[i] for i in available_indices]
|
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
|
||||||
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
|
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
|
||||||
indices = [available_indices[i] for i in sortish_indices]
|
indices = [self.available_indices[i] for i in sortish_indices]
|
||||||
assert len(indices) == self.num_samples
|
assert len(indices) == self.num_samples
|
||||||
return iter(indices)
|
return iter(indices)
|
||||||
|
|
||||||
def get_indices_for_rank(self) -> np.array:
|
@cached_property
|
||||||
|
def available_indices(self) -> np.array:
|
||||||
indices = list(range(len(self.dataset)))
|
indices = list(range(len(self.dataset)))
|
||||||
# add extra samples to make it evenly divisible
|
# add extra samples to make it evenly divisible
|
||||||
indices += indices[: (self.total_size - len(indices))]
|
indices += indices[: (self.total_size - len(indices))]
|
||||||
|
|||||||
Reference in New Issue
Block a user