Merge branch 'master' into from_scratch_training
This commit is contained in:
@@ -42,6 +42,7 @@ class LmSeqsDataset(Dataset):
|
||||
self.check()
|
||||
self.remove_long_sequences()
|
||||
self.remove_empty_sequences()
|
||||
self.remove_unknown_sequences()
|
||||
self.check()
|
||||
self.print_statistics()
|
||||
|
||||
@@ -109,6 +110,22 @@ class LmSeqsDataset(Dataset):
|
||||
new_size = len(self)
|
||||
logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
|
||||
|
||||
def remove_unknown_sequences(self):
    """
    Drop the sequences in which unknown tokens account for half or more of the tokens.

    Nothing is done when `self.params.special_tok_ids` has no "unk_token" entry.
    """
    # Guard clause: without an unknown-token id there is nothing to filter on.
    if "unk_token" not in self.params.special_tok_ids:
        return

    unk_id = self.params.special_tok_ids["unk_token"]
    size_before = len(self)

    # Count the occurrences of the unknown token in each sequence.
    per_sequence_unk = [np.count_nonzero(sequence == unk_id) for sequence in self.token_ids]
    unk_counts = np.array(per_sequence_unk)

    # Keep only sequences whose share of unknown tokens is strictly below 50%.
    # NOTE(review): assumes self.token_ids and self.lengths are numpy arrays that
    # accept a boolean mask as index -- confirm against the constructor.
    keep_mask = (unk_counts / self.lengths) < 0.5
    self.token_ids = self.token_ids[keep_mask]
    self.lengths = self.lengths[keep_mask]

    size_after = len(self)
    logger.info(f"Remove {size_before - size_after} sequences with a high level of unknown tokens (50%).")
|
||||
|
||||
def print_statistics(self):
|
||||
"""
|
||||
Print some statistics on the corpus. Only the master process.
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"activation": "gelu",
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"max_position_embeddings": 512,
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"sinusoidal_pos_embds": true,
|
||||
"tie_weights_": true,
|
||||
"vocab_size": 119547
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"vocab_size": 50265,
|
||||
"hidden_size": 768,
|
||||
"num_hidden_layers": 6,
|
||||
"num_attention_heads": 12,
|
||||
"intermediate_size": 3072,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"max_position_embeddings": 514,
|
||||
"type_vocab_size": 1,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_eps": 0.00001
|
||||
}
|
||||
@@ -344,6 +344,7 @@ def full_text_generation(
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
repetition_penalty=1.0,
|
||||
**kwargs
|
||||
):
|
||||
classifier, class_id = get_classifier(discrim, class_label, device)
|
||||
@@ -368,7 +369,14 @@ def full_text_generation(
|
||||
raise Exception("Specify either a bag of words or a discriminator")
|
||||
|
||||
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
||||
model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=context,
|
||||
device=device,
|
||||
length=length,
|
||||
sample=sample,
|
||||
perturb=False,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
@@ -401,6 +409,7 @@ def full_text_generation(
|
||||
gamma=gamma,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
pert_gen_tok_texts.append(pert_gen_tok_text)
|
||||
if classifier is not None:
|
||||
@@ -437,6 +446,7 @@ def generate_text_pplm(
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
repetition_penalty=1.0,
|
||||
):
|
||||
output_so_far = None
|
||||
if context:
|
||||
@@ -508,6 +518,13 @@ def generate_text_pplm(
|
||||
|
||||
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
|
||||
for token_idx in set(output_so_far[0].tolist()):
|
||||
if pert_logits[0, token_idx] < 0:
|
||||
pert_logits[0, token_idx] *= repetition_penalty
|
||||
else:
|
||||
pert_logits[0, token_idx] /= repetition_penalty
|
||||
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
|
||||
if classifier is not None:
|
||||
@@ -588,6 +605,7 @@ def run_pplm_example(
|
||||
seed=0,
|
||||
no_cuda=False,
|
||||
colorama=False,
|
||||
repetition_penalty=1.0,
|
||||
):
|
||||
# set Random seed
|
||||
torch.manual_seed(seed)
|
||||
@@ -655,6 +673,7 @@ def run_pplm_example(
|
||||
gamma=gamma,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
|
||||
# untokenize unperturbed text
|
||||
@@ -767,6 +786,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
||||
parser.add_argument(
|
||||
"--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run_pplm_example(**vars(args))
|
||||
|
||||
Reference in New Issue
Block a user