good quality generation example for GPT, GPT-2, Transfo-XL, XLNet
This commit is contained in:
@@ -97,7 +97,6 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
|
||||
target_mapping[:, 0, -1] = 1.0 # predict last token
|
||||
inp_q = target_mapping[:, 0, :].clone() # predict last token
|
||||
|
||||
sequence_labels = None
|
||||
lm_labels = None
|
||||
@@ -124,14 +123,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
num_labels=self.type_sequence_label_size)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
|
||||
|
||||
def set_seed(self):
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetModel(config)
|
||||
model.eval()
|
||||
|
||||
@@ -153,7 +152,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
@@ -161,7 +160,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
|
||||
|
||||
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
|
||||
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
|
||||
|
||||
result = {
|
||||
"loss_1": loss_1,
|
||||
@@ -193,7 +192,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetForQuestionAnswering(config)
|
||||
model.eval()
|
||||
|
||||
@@ -243,7 +242,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetForSequenceClassification(config)
|
||||
model.eval()
|
||||
|
||||
@@ -269,7 +268,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels,
|
||||
target_mapping, segment_ids, lm_labels,
|
||||
sequence_labels, is_impossible_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
return config, inputs_dict
|
||||
|
||||
Reference in New Issue
Block a user