improve device usage

This commit is contained in:
Rémi Louf
2019-12-06 15:45:09 +01:00
committed by Julien Chaumond
parent c0707a85d2
commit 2a64107e44
4 changed files with 21 additions and 19 deletions

View File

@@ -185,7 +185,7 @@ def save_summaries(summaries, path, original_document_name):
def build_data_iterator(args, tokenizer):
dataset = load_and_cache_examples(args, tokenizer)
sampler = SequentialSampler(dataset)
collate_fn = lambda data: collate(data, tokenizer, block_size=512)
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
iterator = DataLoader(
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
)
@@ -198,7 +198,7 @@ def load_and_cache_examples(args, tokenizer):
return dataset
def collate(data, tokenizer, block_size):
def collate(data, tokenizer, block_size, device):
""" Collate formats the data passed to the data loader.
In particular we tokenize the data batch after batch to avoid keeping them
@@ -224,9 +224,9 @@ def collate(data, tokenizer, block_size):
batch = Batch(
document_names=names,
batch_size=len(encoded_stories),
src=encoded_stories,
segs=encoder_token_type_ids,
mask_src=encoder_mask,
src=encoded_stories.to(device),
segs=encoder_token_type_ids.to(device),
mask_src=encoder_mask.to(device),
tgt_str=summaries,
)
@@ -271,10 +271,10 @@ def main():
)
# EVALUATION options
parser.add_argument(
"--visible_gpus",
default=-1,
type=int,
help="Number of GPUs with which to do the training.",
"--to_cpu",
default=False,
type=bool,
help="Whether to force the execution on CPU.",
)
parser.add_argument(
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
@@ -311,8 +311,11 @@ def main():
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
)
args = parser.parse_args()
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
# Select device (distibuted not available)
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.to_cpu else "cpu")
# Check the existence of directories
if not args.summaries_output_dir:
args.summaries_output_dir = args.documents_dir