Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it leaves room for explicit variable
names, which make the code easier to understand.
This commit is contained in:
@@ -25,9 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
Batch = namedtuple(
|
||||
"Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"]
|
||||
)
|
||||
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
@@ -48,13 +46,14 @@ def evaluate(args):
|
||||
|
||||
import rouge
|
||||
import nltk
|
||||
nltk.download('punkt')
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=['rouge-n', 'rouge-l'],
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
max_n=2,
|
||||
limit_length=True,
|
||||
length_limit=args.beam_size,
|
||||
length_limit_type='words',
|
||||
length_limit_type="words",
|
||||
apply_avg=True,
|
||||
apply_best=False,
|
||||
alpha=0.5, # Default F1_score
|
||||
@@ -161,15 +160,15 @@ Recall >> {:.3f}
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}""".format(
|
||||
scores['rouge-1']['f'],
|
||||
scores['rouge-1']['p'],
|
||||
scores['rouge-1']['r'],
|
||||
scores['rouge-2']['f'],
|
||||
scores['rouge-2']['p'],
|
||||
scores['rouge-2']['r'],
|
||||
scores['rouge-l']['f'],
|
||||
scores['rouge-l']['p'],
|
||||
scores['rouge-l']['r'],
|
||||
scores["rouge-1"]["f"],
|
||||
scores["rouge-1"]["p"],
|
||||
scores["rouge-1"]["r"],
|
||||
scores["rouge-2"]["f"],
|
||||
scores["rouge-2"]["p"],
|
||||
scores["rouge-2"]["r"],
|
||||
scores["rouge-l"]["f"],
|
||||
scores["rouge-l"]["p"],
|
||||
scores["rouge-l"]["r"],
|
||||
)
|
||||
|
||||
|
||||
@@ -187,9 +186,7 @@ def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
||||
iterator = DataLoader(
|
||||
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
||||
)
|
||||
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
|
||||
|
||||
return iterator
|
||||
|
||||
@@ -210,14 +207,9 @@ def collate(data, tokenizer, block_size, device):
|
||||
names = [name for name, _, _ in data]
|
||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||
|
||||
encoded_text = [
|
||||
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
|
||||
]
|
||||
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
|
||||
encoded_stories = torch.tensor(
|
||||
[
|
||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
|
||||
for story, _ in encoded_text
|
||||
]
|
||||
[fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
|
||||
)
|
||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||
@@ -272,38 +264,23 @@ def main():
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
"--max_length", default=200, type=int, help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
|
||||
Reference in New Issue
Block a user