[s2s] distributed eval cleanup (#7186)
This commit is contained in:
@@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
|
|||||||
--fp16 \
|
--fp16 \
|
||||||
--bs 32
|
--bs 32
|
||||||
```
|
```
|
||||||
|
### Multi-GPU Evalulation
|
||||||
|
here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases
|
||||||
|
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
|
||||||
|
`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all clargs.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
|
||||||
|
--model_name sshleifer/distilbart-large-xsum-12-3 \
|
||||||
|
--save_dir xsum_generations \
|
||||||
|
--data_dir xsum \
|
||||||
|
--fp16 # you can pass generate kwargs like num_beams here, just like run_eval.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Contributions that implement this command for other distributed hardware setups are welcome!
|
||||||
|
|
||||||
#### run_eval tips and tricks
|
#### run_eval tips and tricks
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import time
|
|||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -22,7 +22,7 @@ try:
|
|||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
lmap,
|
lmap,
|
||||||
load_json,
|
load_json,
|
||||||
parse_numeric_cl_kwargs,
|
parse_numeric_n_bool_cl_kwargs,
|
||||||
save_json,
|
save_json,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
write_txt_file,
|
write_txt_file,
|
||||||
@@ -34,7 +34,7 @@ except ImportError:
|
|||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
lmap,
|
lmap,
|
||||||
load_json,
|
load_json,
|
||||||
parse_numeric_cl_kwargs,
|
parse_numeric_n_bool_cl_kwargs,
|
||||||
save_json,
|
save_json,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
write_txt_file,
|
write_txt_file,
|
||||||
@@ -50,7 +50,6 @@ def eval_data_dir(
|
|||||||
type_path="val",
|
type_path="val",
|
||||||
n_obs=None,
|
n_obs=None,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
num_beams: int = 4,
|
|
||||||
task="summarization",
|
task="summarization",
|
||||||
local_rank=None,
|
local_rank=None,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
@@ -81,23 +80,21 @@ def eval_data_dir(
|
|||||||
n_obs=n_obs,
|
n_obs=n_obs,
|
||||||
prefix=model.config.prefix,
|
prefix=model.config.prefix,
|
||||||
)
|
)
|
||||||
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
|
# I set shuffle=True for a more accurate progress bar.
|
||||||
|
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
|
||||||
|
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
|
||||||
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
|
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
|
||||||
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
|
|
||||||
results = []
|
results = []
|
||||||
for batch in tqdm(data_loader):
|
for batch in tqdm(data_loader):
|
||||||
summaries = model.generate(
|
summaries = model.generate(
|
||||||
input_ids=batch["input_ids"].to(model.device),
|
input_ids=batch["input_ids"].to(model.device),
|
||||||
attention_mask=batch["attention_mask"].to(model.device),
|
attention_mask=batch["attention_mask"].to(model.device),
|
||||||
num_beams=num_beams,
|
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
|
preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
|
|
||||||
ids = batch["ids"]
|
ids = batch["ids"]
|
||||||
for i in range(len(labels)):
|
for i, pred in enumerate(preds):
|
||||||
label, pred = labels[i], preds[i]
|
results.append(dict(pred=pred, id=ids[i].item()))
|
||||||
results.append(dict(pred=pred, label=label, id=ids[i].item()))
|
|
||||||
save_json(results, save_path)
|
save_json(results, save_path)
|
||||||
return results, sampler.num_replicas
|
return results, sampler.num_replicas
|
||||||
|
|
||||||
@@ -139,8 +136,8 @@ def run_generate():
|
|||||||
parser.add_argument("--debug", action="store_true")
|
parser.add_argument("--debug", action="store_true")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
generate_kwargs = parse_numeric_cl_kwargs(rest)
|
generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
|
||||||
if generate_kwargs:
|
if generate_kwargs and args.local_rank <= 0:
|
||||||
print(f"parsed the following generate kwargs: {generate_kwargs}")
|
print(f"parsed the following generate kwargs: {generate_kwargs}")
|
||||||
json_save_dir = Path(args.save_dir + "_tmp")
|
json_save_dir = Path(args.save_dir + "_tmp")
|
||||||
Path(json_save_dir).mkdir(exist_ok=True) # this handles locking.
|
Path(json_save_dir).mkdir(exist_ok=True) # this handles locking.
|
||||||
@@ -168,7 +165,10 @@ def run_generate():
|
|||||||
save_dir = Path(args.save_dir)
|
save_dir = Path(args.save_dir)
|
||||||
save_dir.mkdir(exist_ok=True)
|
save_dir.mkdir(exist_ok=True)
|
||||||
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
|
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
|
||||||
preds, labels = combine_partial_results(partial_results)
|
preds = combine_partial_results(partial_results)
|
||||||
|
tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
|
||||||
|
labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
|
||||||
|
|
||||||
# Calculate metrics, save metrics, and save _generations.txt
|
# Calculate metrics, save metrics, and save _generations.txt
|
||||||
calc_bleu = "translation" in args.task
|
calc_bleu = "translation" in args.task
|
||||||
score_fn = calculate_bleu if calc_bleu else calculate_rouge
|
score_fn = calculate_bleu if calc_bleu else calculate_rouge
|
||||||
@@ -179,7 +179,7 @@ def run_generate():
|
|||||||
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
|
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
|
||||||
# TODO(@stas00): add whatever metadata to metrics
|
# TODO(@stas00): add whatever metadata to metrics
|
||||||
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
|
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
|
||||||
save_json(metrics, metrics_save_path)
|
save_json(metrics, metrics_save_path, indent=None)
|
||||||
print(metrics)
|
print(metrics)
|
||||||
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
|
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
|
||||||
if args.debug:
|
if args.debug:
|
||||||
@@ -188,15 +188,14 @@ def run_generate():
|
|||||||
shutil.rmtree(json_save_dir)
|
shutil.rmtree(json_save_dir)
|
||||||
|
|
||||||
|
|
||||||
def combine_partial_results(partial_results) -> Tuple[List, List]:
|
def combine_partial_results(partial_results) -> List:
|
||||||
"""Concatenate partial results into one file, then sort it by id."""
|
"""Concatenate partial results into one file, then sort it by id."""
|
||||||
records = []
|
records = []
|
||||||
for partial_result in partial_results:
|
for partial_result in partial_results:
|
||||||
records.extend(partial_result)
|
records.extend(partial_result)
|
||||||
records = list(sorted(records, key=lambda x: x["id"]))
|
records = list(sorted(records, key=lambda x: x["id"]))
|
||||||
preds = [x["pred"] for x in records]
|
preds = [x["pred"] for x in records]
|
||||||
labels = [x["label"] for x in records]
|
return preds
|
||||||
return preds, labels
|
|
||||||
|
|
||||||
|
|
||||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
|
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ def run_generate(verbose=True):
|
|||||||
scores["info"] = args.info
|
scores["info"] = args.info
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(*scores)
|
print(scores)
|
||||||
|
|
||||||
if args.score_path is not None:
|
if args.score_path is not None:
|
||||||
path = args.score_path
|
path = args.score_path
|
||||||
|
|||||||
@@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
|
|||||||
def get_char_lens(data_file):
|
def get_char_lens(data_file):
|
||||||
return [len(x) for x in Path(data_file).open().readlines()]
|
return [len(x) for x in Path(data_file).open().readlines()]
|
||||||
|
|
||||||
def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
|
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
|
||||||
if distributed:
|
if distributed:
|
||||||
return DistributedSortishSampler(self, batch_size, **kwargs)
|
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
|
||||||
else:
|
else:
|
||||||
return SortishSampler(self.src_lens, batch_size)
|
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
raise NotImplementedError("You must implement this")
|
raise NotImplementedError("You must implement this")
|
||||||
@@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
|||||||
class SortishSampler(Sampler):
|
class SortishSampler(Sampler):
|
||||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||||
|
|
||||||
def __init__(self, data, batch_size):
|
def __init__(self, data, batch_size, shuffle=True):
|
||||||
self.data, self.bs = data, batch_size
|
self.data, self.bs, self.shuffle = data, batch_size, shuffle
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(sortish_sampler_indices(self.data, self.bs))
|
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
|
||||||
|
|
||||||
|
|
||||||
def sortish_sampler_indices(data: List, bs: int) -> np.array:
|
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
|
||||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||||
|
if not shuffle:
|
||||||
|
return np.argsort(np.array(data) * -1)
|
||||||
|
|
||||||
def key_fn(i):
|
def key_fn(i):
|
||||||
return data[i]
|
return data[i]
|
||||||
@@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
|
|||||||
class DistributedSortishSampler(Sampler):
|
class DistributedSortishSampler(Sampler):
|
||||||
"""Copied from torch DistributedSampler"""
|
"""Copied from torch DistributedSampler"""
|
||||||
|
|
||||||
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
|
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
|
||||||
if num_replicas is None:
|
if num_replicas is None:
|
||||||
if not dist.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
@@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
|
|||||||
self.num_samples = len(self.available_indices)
|
self.num_samples = len(self.available_indices)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.add_extra_examples = add_extra_examples
|
self.add_extra_examples = add_extra_examples
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
def __iter__(self) -> Iterable:
|
def __iter__(self) -> Iterable:
|
||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
g.manual_seed(self.epoch)
|
g.manual_seed(self.epoch)
|
||||||
|
|
||||||
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
|
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
|
||||||
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
|
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
|
||||||
indices = [self.available_indices[i] for i in sortish_indices]
|
indices = [self.available_indices[i] for i in sortish_indices]
|
||||||
assert len(indices) == self.num_samples
|
assert len(indices) == self.num_samples
|
||||||
return iter(indices)
|
return iter(indices)
|
||||||
@@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
|
|||||||
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
||||||
|
|
||||||
|
|
||||||
def save_json(content, path):
|
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
json.dump(content, f, indent=4)
|
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_json(path):
|
def load_json(path):
|
||||||
|
|||||||
Reference in New Issue
Block a user