generate_text_pplm now works with batch_size > 1
This commit is contained in:
committed by
Julien Chaumond
parent
893d0d64fe
commit
a59fdd1627
@@ -231,7 +231,8 @@ def perturb_past(
|
|||||||
prediction = classifier(new_accumulated_hidden /
|
prediction = classifier(new_accumulated_hidden /
|
||||||
(curr_length + 1 + horizon_length))
|
(curr_length + 1 + horizon_length))
|
||||||
|
|
||||||
label = torch.tensor([class_label], device=device,
|
label = torch.tensor(prediction.shape[0] * [class_label],
|
||||||
|
device=device,
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
discrim_loss = ce_loss(prediction, label)
|
discrim_loss = ce_loss(prediction, label)
|
||||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||||
@@ -508,11 +509,12 @@ def generate_text_pplm(
|
|||||||
gm_scale=0.9,
|
gm_scale=0.9,
|
||||||
kl_scale=0.01,
|
kl_scale=0.01,
|
||||||
):
|
):
|
||||||
output_so_far = (
|
output_so_far = None
|
||||||
torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
|
if context:
|
||||||
if context
|
context_t = torch.tensor(context, device=device, dtype=torch.long)
|
||||||
else None
|
while len(context_t.shape) < 2:
|
||||||
)
|
context_t = context_t.unsqueeze(0)
|
||||||
|
output_so_far = context_t
|
||||||
|
|
||||||
# collect one hot vectors for bags of words
|
# collect one hot vectors for bags of words
|
||||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
|
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
|
||||||
|
|||||||
Reference in New Issue
Block a user