Update naming + remove f string in run_lm_finetuning example
This commit is contained in:
@@ -59,7 +59,7 @@ class TextDataset(Dataset):
|
||||
def __init__(self, tokenizer, file_path='train', block_size=512):
|
||||
assert os.path.isfile(file_path)
|
||||
directory, filename = os.path.split(file_path)
|
||||
cached_features_file = os.path.join(directory, 'cached_lm_{}_{}'.format(block_size, filename))
|
||||
cached_features_file = os.path.join(directory, 'cached_lm_' + block_size + '_' + filename)
|
||||
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
@@ -110,7 +110,7 @@ def mask_tokens(inputs, tokenizer, args):
|
||||
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
||||
probability_matrix = torch.full(labels.shape, args.mlm_probability)
|
||||
probability_matrix *= torch.tensor(
|
||||
[tokenizer.get_sequence_ids(val, special_tokens_present=True) for val in labels.tolist()],
|
||||
[tokenizer.get_special_tokens_mask(val, special_tokens_present=True) for val in labels.tolist()],
|
||||
dtype=torch.float
|
||||
)
|
||||
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||
|
||||
Reference in New Issue
Block a user