[s2s] distributed eval allows num_return_sequences > 1 (#7254)
This commit is contained in:
@@ -13,7 +13,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||
from utils import calculate_bleu, calculate_rouge, chunks, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
@@ -22,12 +22,6 @@ logger = getLogger(__name__)
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def generate_summaries_or_translations(
|
||||
examples: List[str],
|
||||
out_file: str,
|
||||
|
||||
Reference in New Issue
Block a user