pad sequence with 0, mask with -1
This commit is contained in:
@@ -58,7 +58,7 @@ class TextDataset(Dataset):
|
|||||||
[2] https://github.com/abisee/cnn-dailymail/
|
[2] https://github.com/abisee/cnn-dailymail/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512):
|
def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
|
||||||
assert os.path.isdir(data_dir)
|
assert os.path.isdir(data_dir)
|
||||||
|
|
||||||
# Load features that have already been computed if present
|
# Load features that have already been computed if present
|
||||||
@@ -165,7 +165,12 @@ def _fit_to_block_size(sequence, block_size):
|
|||||||
if len(sequence) > block_size:
|
if len(sequence) > block_size:
|
||||||
return sequence[:block_size]
|
return sequence[:block_size]
|
||||||
else:
|
else:
|
||||||
return sequence.extend([-1] * (block_size - len(sequence)))
|
return sequence.extend([0] * (block_size - len(sequence)))
|
||||||
|
|
||||||
|
|
||||||
|
def mask_padding_tokens(sequence):
|
||||||
|
""" Replace the padding token with -1 values """
|
||||||
|
return [s if s != 0 else -1 for s in sequence]
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer):
|
def load_and_cache_examples(args, tokenizer):
|
||||||
@@ -219,11 +224,8 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||||
logger.info(
|
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||||
" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
|
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
|
||||||
args.train_batch_size
|
args.train_batch_size
|
||||||
* args.gradient_accumulation_steps
|
* args.gradient_accumulation_steps
|
||||||
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||||
@@ -242,7 +244,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
source = ([s for s, _ in batch]).to(args.device)
|
source = ([s for s, _ in batch]).to(args.device)
|
||||||
target = ([t for _, t in batch]).to(args.device)
|
target = ([t for _, t in batch]).to(args.device)
|
||||||
model.train()
|
model.train()
|
||||||
outputs = model(source, target)
|
outputs = model(source, target, decoder_lm_labels=mask_padding_tokens(target))
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user