pad sequence with 0, mask with -1

This commit is contained in:
Rémi Louf
2019-10-17 17:44:20 +02:00
parent dc580dd4c7
commit b915ba9dfe

View File

@@ -58,7 +58,7 @@ class TextDataset(Dataset):
[2] https://github.com/abisee/cnn-dailymail/ [2] https://github.com/abisee/cnn-dailymail/
""" """
def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512): def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
assert os.path.isdir(data_dir) assert os.path.isdir(data_dir)
# Load features that have already been computed if present # Load features that have already been computed if present
@@ -165,7 +165,12 @@ def _fit_to_block_size(sequence, block_size):
if len(sequence) > block_size: if len(sequence) > block_size:
return sequence[:block_size] return sequence[:block_size]
else: else:
return sequence.extend([-1] * (block_size - len(sequence))) return sequence.extend([0] * (block_size - len(sequence)))
def mask_padding_tokens(sequence):
""" Replace the padding token with -1 values """
return [s if s != 0 else -1 for s in sequence]
def load_and_cache_examples(args, tokenizer): def load_and_cache_examples(args, tokenizer):
@@ -219,11 +224,8 @@ def train(args, train_dataset, model, tokenizer):
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info( logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size args.train_batch_size
* args.gradient_accumulation_steps * args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1), * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
@@ -242,7 +244,7 @@ def train(args, train_dataset, model, tokenizer):
source = ([s for s, _ in batch]).to(args.device) source = ([s for s, _ in batch]).to(args.device)
target = ([t for _, t in batch]).to(args.device) target = ([t for _, t in batch]).to(args.device)
model.train() model.train()
outputs = model(source, target) outputs = model(source, target, decoder_lm_labels=mask_padding_tokens(target))
loss = outputs[0] loss = outputs[0]
loss.backward() loss.backward()