improve device usage
This commit is contained in:
committed by
Julien Chaumond
parent
c0707a85d2
commit
2a64107e44
@@ -185,7 +185,7 @@ def save_summaries(summaries, path, original_document_name):
|
||||
def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512)
|
||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
||||
iterator = DataLoader(
|
||||
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
||||
)
|
||||
@@ -198,7 +198,7 @@ def load_and_cache_examples(args, tokenizer):
|
||||
return dataset
|
||||
|
||||
|
||||
def collate(data, tokenizer, block_size):
|
||||
def collate(data, tokenizer, block_size, device):
|
||||
""" Collate formats the data passed to the data loader.
|
||||
|
||||
In particular we tokenize the data batch after batch to avoid keeping them
|
||||
@@ -224,9 +224,9 @@ def collate(data, tokenizer, block_size):
|
||||
batch = Batch(
|
||||
document_names=names,
|
||||
batch_size=len(encoded_stories),
|
||||
src=encoded_stories,
|
||||
segs=encoder_token_type_ids,
|
||||
mask_src=encoder_mask,
|
||||
src=encoded_stories.to(device),
|
||||
segs=encoder_token_type_ids.to(device),
|
||||
mask_src=encoder_mask.to(device),
|
||||
tgt_str=summaries,
|
||||
)
|
||||
|
||||
@@ -271,10 +271,10 @@ def main():
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--visible_gpus",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="Number of GPUs with which to do the training.",
|
||||
"--to_cpu",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||
@@ -311,8 +311,11 @@ def main():
|
||||
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
|
||||
|
||||
# Select device (distributed not available)
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.to_cpu else "cpu")
|
||||
|
||||
# Check the existence of directories
|
||||
if not args.summaries_output_dir:
|
||||
args.summaries_output_dir = args.documents_dir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user