Fix E731 flake8 warning (x3).

This commit is contained in:
Aymeric Augustin
2019-12-21 18:01:54 +01:00
parent eed46f38b7
commit 7dce8dc7ac
3 changed files with 8 additions and 3 deletions

View File

@@ -184,7 +184,10 @@ def save_rouge_scores(str_scores):
def build_data_iterator(args, tokenizer):
dataset = load_and_cache_examples(args, tokenizer)
sampler = SequentialSampler(dataset)
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
def collate_fn(data):
return collate(data, tokenizer, block_size=512, device=args.device)
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
return iterator