standardizing API across models - XLNetForSeqClass working
This commit is contained in:
@@ -67,6 +67,8 @@ def main():
|
|||||||
help="The initial learning rate for Adam.")
|
help="The initial learning rate for Adam.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
|
parser.add_argument("--max_steps", default=-1, type=int,
|
||||||
|
help="If > 0 limit the number of training steps to perform, you should choose only one of num_train_epochs and max_steps.")
|
||||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||||
help="Proportion of training to perform linear learning rate warmup for. "
|
help="Proportion of training to perform linear learning rate warmup for. "
|
||||||
"E.g., 0.1 = 10%% of training.")
|
"E.g., 0.1 = 10%% of training.")
|
||||||
@@ -189,8 +191,7 @@ def main():
|
|||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
nb_tr_steps = 0
|
curr_tr_loss, curr_steps = 0., 1
|
||||||
tr_loss = 0
|
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
@@ -229,12 +230,15 @@ def main():
|
|||||||
|
|
||||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
train_sampler = RandomSampler(train_data)
|
train_sampler = SequentialSampler(train_data) # RandomSampler(train_data)
|
||||||
else:
|
else:
|
||||||
train_sampler = DistributedSampler(train_data)
|
train_sampler = DistributedSampler(train_data)
|
||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
if args.max_steps > 0:
|
||||||
|
num_train_optimization_steps = args.max_steps
|
||||||
|
else:
|
||||||
|
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer
|
# Prepare optimizer
|
||||||
|
|
||||||
@@ -275,22 +279,16 @@ def main():
|
|||||||
logger.info(" Num steps = %d", num_train_optimization_steps)
|
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
for _ in trange(int(args.num_train_epochs) if args.max_steps <= 0 else int('Inf'),
|
||||||
tr_loss = 0
|
desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
||||||
nb_tr_examples, nb_tr_steps = 0, 0
|
for step, batch in enumerate(tqdm(train_dataloader,
|
||||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
desc="Iteration",
|
||||||
|
disable=args.local_rank not in [-1, 0])):
|
||||||
batch = tuple(t.to(device) for t in batch)
|
batch = tuple(t.to(device) for t in batch)
|
||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
|
||||||
# define a new function to compute loss values for both output_modes
|
# define a new function to compute loss values for both output_modes
|
||||||
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
loss, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
|
||||||
|
|
||||||
if output_mode == "classification":
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
|
||||||
elif output_mode == "regression":
|
|
||||||
loss_fct = MSELoss()
|
|
||||||
loss = loss_fct(logits.view(-1), label_ids.view(-1))
|
|
||||||
|
|
||||||
if n_gpu > 1:
|
if n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu.
|
loss = loss.mean() # mean() to average on multi-gpu.
|
||||||
@@ -302,12 +300,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
if args.clip_gradients > 0.0:
|
gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
|
|
||||||
|
|
||||||
tr_loss += loss.item()
|
curr_tr_loss += loss.item()
|
||||||
nb_tr_examples += input_ids.size(0)
|
curr_steps += 1
|
||||||
nb_tr_steps += 1
|
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
# modify learning rate with special warm up BERT uses
|
# modify learning rate with special warm up BERT uses
|
||||||
@@ -318,10 +314,19 @@ def main():
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
|
if args.local_rank in [-1, 0] and (args.log_every <= 0 or (global_step + 1) % args.log_every == 0):
|
||||||
if not args.fp16:
|
learning_rate = optimizer.get_lr()[0] if not args.fp16 else lr_this_step
|
||||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
logger.info("[{}] | gnorm {:.2f} lr {:8.6f} | loss {:.2f}".format(
|
||||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
global_step, gnorm, learning_rate, curr_tr_loss / curr_steps))
|
||||||
|
tb_writer.add_scalar('lr', learning_rate, global_step)
|
||||||
|
tb_writer.add_scalar('loss', curr_tr_loss / curr_steps, global_step)
|
||||||
|
curr_tr_loss, curr_steps = 0., 1
|
||||||
|
|
||||||
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||||
### Example:
|
### Example:
|
||||||
@@ -435,7 +440,7 @@ def main():
|
|||||||
preds = np.squeeze(preds)
|
preds = np.squeeze(preds)
|
||||||
result = compute_metrics(task_name, preds, out_label_ids)
|
result = compute_metrics(task_name, preds, out_label_ids)
|
||||||
|
|
||||||
loss = tr_loss/global_step if args.do_train else None
|
loss = curr_tr_loss/curr_steps if args.do_train else None
|
||||||
|
|
||||||
result['eval_loss'] = eval_loss
|
result['eval_loss'] = eval_loss
|
||||||
result['global_step'] = global_step
|
result['global_step'] = global_step
|
||||||
@@ -508,7 +513,7 @@ def main():
|
|||||||
preds = np.argmax(preds, axis=1)
|
preds = np.argmax(preds, axis=1)
|
||||||
result = compute_metrics(task_name, preds, out_label_ids)
|
result = compute_metrics(task_name, preds, out_label_ids)
|
||||||
|
|
||||||
loss = tr_loss/global_step if args.do_train else None
|
loss = curr_tr_loss/curr_steps if args.do_train else None
|
||||||
|
|
||||||
result['eval_loss'] = eval_loss
|
result['eval_loss'] = eval_loss
|
||||||
result['global_step'] = global_step
|
result['global_step'] = global_step
|
||||||
|
|||||||
@@ -270,15 +270,13 @@ class BertEmbeddings(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertSelfAttention(nn.Module):
|
class BertSelfAttention(nn.Module):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertSelfAttention, self).__init__()
|
super(BertSelfAttention, self).__init__()
|
||||||
if config.hidden_size % config.num_attention_heads != 0:
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The hidden size (%d) is not a multiple of the number of attention "
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
self.keep_multihead_output = keep_multihead_output
|
|
||||||
self.multihead_output = None
|
|
||||||
|
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
@@ -329,9 +327,9 @@ class BertSelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
if self.output_attentions:
|
|
||||||
return attention_probs, context_layer
|
outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
|
||||||
return context_layer
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class BertSelfOutput(nn.Module):
|
class BertSelfOutput(nn.Module):
|
||||||
@@ -349,11 +347,10 @@ class BertSelfOutput(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertAttention(nn.Module):
|
class BertAttention(nn.Module):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertAttention, self).__init__()
|
super(BertAttention, self).__init__()
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
self.self = BertSelfAttention(config, output_attentions=output_attentions,
|
self.self = BertSelfAttention(config, output_attentions=output_attentions)
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
self.output = BertSelfOutput(config)
|
self.output = BertSelfOutput(config)
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
@@ -374,13 +371,10 @@ class BertAttention(nn.Module):
|
|||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
|
||||||
def forward(self, input_tensor, attention_mask, head_mask=None):
|
def forward(self, input_tensor, attention_mask, head_mask=None):
|
||||||
self_output = self.self(input_tensor, attention_mask, head_mask)
|
self_outputs = self.self(input_tensor, attention_mask, head_mask)
|
||||||
if self.output_attentions:
|
attention_output = self.output(self_outputs[0], input_tensor)
|
||||||
attentions, self_output = self_output
|
outputs = [attention_output] + self_outputs[1:] # add attentions if we output them
|
||||||
attention_output = self.output(self_output, input_tensor)
|
return outputs
|
||||||
if self.output_attentions:
|
|
||||||
return attentions, attention_output
|
|
||||||
return attention_output
|
|
||||||
|
|
||||||
|
|
||||||
class BertIntermediate(nn.Module):
|
class BertIntermediate(nn.Module):
|
||||||
@@ -413,48 +407,52 @@ class BertOutput(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertLayer(nn.Module):
|
class BertLayer(nn.Module):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertLayer, self).__init__()
|
super(BertLayer, self).__init__()
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
self.attention = BertAttention(config, output_attentions=output_attentions,
|
self.attention = BertAttention(config, output_attentions=output_attentions)
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask, head_mask=None):
|
def forward(self, hidden_states, attention_mask, head_mask=None):
|
||||||
attention_output = self.attention(hidden_states, attention_mask, head_mask)
|
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||||
if self.output_attentions:
|
intermediate_output = self.intermediate(attention_outputs[0])
|
||||||
attentions, attention_output = attention_output
|
|
||||||
intermediate_output = self.intermediate(attention_output)
|
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
if self.output_attentions:
|
outputs = [layer_output] + attention_outputs[1:] # add attentions if we output them
|
||||||
return attentions, layer_output
|
return outputs
|
||||||
return layer_output
|
|
||||||
|
|
||||||
|
|
||||||
class BertEncoder(nn.Module):
|
class BertEncoder(nn.Module):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(BertEncoder, self).__init__()
|
super(BertEncoder, self).__init__()
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
layer = BertLayer(config, output_attentions=output_attentions,
|
self.output_hidden_states = output_hidden_states
|
||||||
keep_multihead_output=keep_multihead_output)
|
layer = BertLayer(config, output_attentions=output_attentions)
|
||||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
|
def forward(self, hidden_states, attention_mask, head_mask=None):
|
||||||
all_encoder_layers = []
|
all_hidden_states = []
|
||||||
all_attentions = []
|
all_attentions = []
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])
|
if self.output_hidden_states:
|
||||||
|
all_hidden_states.append(hidden_states)
|
||||||
|
|
||||||
|
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
attentions, hidden_states = hidden_states
|
all_attentions.append(layer_outputs[1])
|
||||||
all_attentions.append(attentions)
|
|
||||||
if output_all_encoded_layers:
|
# Add last layer
|
||||||
all_encoder_layers.append(hidden_states)
|
if self.output_hidden_states:
|
||||||
if not output_all_encoded_layers:
|
all_hidden_states.append(hidden_states)
|
||||||
all_encoder_layers.append(hidden_states)
|
|
||||||
|
outputs = [hidden_states]
|
||||||
|
if self.output_hidden_states:
|
||||||
|
outputs.append(all_hidden_states)
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
return all_attentions, all_encoder_layers
|
outputs.append(all_attentions)
|
||||||
return all_encoder_layers
|
return outputs # outputs, (hidden states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
class BertPooler(nn.Module):
|
class BertPooler(nn.Module):
|
||||||
@@ -617,12 +615,13 @@ class BertModel(BertPreTrainedModel):
|
|||||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(BertModel, self).__init__(config)
|
super(BertModel, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
self.embeddings = BertEmbeddings(config)
|
self.embeddings = BertEmbeddings(config)
|
||||||
self.encoder = BertEncoder(config, output_attentions=output_attentions,
|
self.encoder = BertEncoder(config, output_attentions=output_attentions,
|
||||||
keep_multihead_output=keep_multihead_output)
|
output_hidden_states=output_hidden_states)
|
||||||
self.pooler = BertPooler(config)
|
self.pooler = BertPooler(config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
@@ -633,13 +632,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def get_multihead_outputs(self):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None):
|
||||||
""" Gather all multi-head outputs.
|
|
||||||
Return: list (layers) of multihead module outputs with gradients
|
|
||||||
"""
|
|
||||||
return [layer.attention.self.multihead_output for layer in self.encoder.layer]
|
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
@@ -676,19 +669,14 @@ class BertModel(BertPreTrainedModel):
|
|||||||
head_mask = [None] * self.config.num_hidden_layers
|
head_mask = [None] * self.config.num_hidden_layers
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||||
encoded_layers = self.encoder(embedding_output,
|
encoder_outputs = self.encoder(embedding_output,
|
||||||
extended_attention_mask,
|
extended_attention_mask,
|
||||||
output_all_encoded_layers=output_all_encoded_layers,
|
head_mask=head_mask)
|
||||||
head_mask=head_mask)
|
sequence_output = encoder_outputs[0]
|
||||||
if self.output_attentions:
|
|
||||||
all_attentions, encoded_layers = encoded_layers
|
|
||||||
sequence_output = encoded_layers[-1]
|
|
||||||
pooled_output = self.pooler(sequence_output)
|
pooled_output = self.pooler(sequence_output)
|
||||||
if not output_all_encoded_layers:
|
|
||||||
encoded_layers = encoded_layers[-1]
|
outputs = [sequence_output, pooled_output] + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||||
if self.output_attentions:
|
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||||
return all_attentions, encoded_layers, pooled_output
|
|
||||||
return encoded_layers, pooled_output
|
|
||||||
|
|
||||||
|
|
||||||
class BertForPreTraining(BertPreTrainedModel):
|
class BertForPreTraining(BertPreTrainedModel):
|
||||||
@@ -746,32 +734,33 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(BertForPreTraining, self).__init__(config)
|
super(BertForPreTraining, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
|
|
||||||
self.bert = BertModel(config, output_attentions=output_attentions,
|
self.bert = BertModel(config, output_attentions=output_attentions,
|
||||||
keep_multihead_output=keep_multihead_output)
|
output_hidden_states=output_hidden_states)
|
||||||
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
next_sentence_label=None, head_mask=None):
|
||||||
output_all_encoded_layers=False, head_mask=head_mask)
|
outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
|
||||||
if self.output_attentions:
|
|
||||||
all_attentions, sequence_output, pooled_output = outputs
|
sequence_output, pooled_output = outputs[:2]
|
||||||
else:
|
|
||||||
sequence_output, pooled_output = outputs
|
|
||||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||||
|
|
||||||
|
outputs = [prediction_scores, seq_relationship_score] + outputs[2:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||||
total_loss = masked_lm_loss + next_sentence_loss
|
total_loss = masked_lm_loss + next_sentence_loss
|
||||||
return total_loss
|
outputs = [total_loss] + outputs
|
||||||
elif self.output_attentions:
|
|
||||||
return all_attentions, prediction_scores, seq_relationship_score
|
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
||||||
return prediction_scores, seq_relationship_score
|
|
||||||
|
|
||||||
|
|
||||||
class BertForMaskedLM(BertPreTrainedModel):
|
class BertForMaskedLM(BertPreTrainedModel):
|
||||||
|
|||||||
@@ -919,9 +919,11 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class XLMModel(XLMPreTrainedModel):
|
class XLMModel(XLMPreTrainedModel):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(XLMModel, self).__init__(config)
|
super(XLMModel, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
|
|
||||||
self.mem_len = config.mem_len
|
self.mem_len = config.mem_len
|
||||||
self.reuse_len = config.reuse_len
|
self.reuse_len = config.reuse_len
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@@ -1038,8 +1040,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
return pos_emb
|
return pos_emb
|
||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
|
||||||
output_all_encoded_layers=True, head_mask=None):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||||
@@ -1188,23 +1189,45 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
mems = [None] * len(self.layer)
|
mems = [None] * len(self.layer)
|
||||||
|
|
||||||
hidden_states = []
|
hidden_states = []
|
||||||
|
attentions = []
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
# cache new mems
|
# cache new mems
|
||||||
new_mems.append(self.cache_mem(output_h, mems[i]))
|
new_mems.append(self.cache_mem(output_h, mems[i]))
|
||||||
|
# Save hidden_states
|
||||||
|
if output_g is None:
|
||||||
|
hidden_states.append(output_h)
|
||||||
|
else:
|
||||||
|
hidden_states.append((output_h, output_g))
|
||||||
|
|
||||||
output_h, output_g = layer_module(output_h, output_g,
|
output_h, output_g = layer_module(output_h, output_g,
|
||||||
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
|
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
|
||||||
r=pos_emb, seg_mat=seg_mat,
|
r=pos_emb, seg_mat=seg_mat,
|
||||||
mems=mems[i], target_mapping=target_mapping,
|
mems=mems[i], target_mapping=target_mapping,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask)
|
||||||
|
# Save last hidden_state
|
||||||
|
if output_g is None:
|
||||||
hidden_states.append(output_h)
|
hidden_states.append(output_h)
|
||||||
|
else:
|
||||||
|
hidden_states.append((output_h, output_g))
|
||||||
|
|
||||||
|
# Select the right output and add dropout
|
||||||
output = self.dropout(output_g if output_g is not None else output_h)
|
output = self.dropout(output_g if output_g is not None else output_h)
|
||||||
|
|
||||||
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||||
output = output.permute(1, 0, 2).contiguous()
|
output = output.permute(1, 0, 2).contiguous()
|
||||||
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
|
if output_g is None:
|
||||||
|
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
|
||||||
|
else:
|
||||||
|
hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
|
||||||
|
|
||||||
return output, hidden_states, new_mems
|
# Build the list of outputs
|
||||||
|
outputs = [output, new_mems]
|
||||||
|
if self.output_attentions:
|
||||||
|
outputs.append(attentions)
|
||||||
|
if self.output_hidden_states:
|
||||||
|
outputs.append(hidden_states)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class XLMPredLayer(nn.Module):
|
class XLMPredLayer(nn.Module):
|
||||||
@@ -1309,14 +1332,15 @@ class XLMLMHeadModel(XLMPreTrainedModel):
|
|||||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(XLMLMHeadModel, self).__init__(config)
|
super(XLMLMHeadModel, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
|
|
||||||
self.attn_type = config.attn_type
|
self.attn_type = config.attn_type
|
||||||
self.same_length = config.same_length
|
self.same_length = config.same_length
|
||||||
|
|
||||||
self.transformer = XLMModel(config, output_attentions=output_attentions,
|
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
|
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
|
||||||
|
|
||||||
# Tie weights
|
# Tie weights
|
||||||
@@ -1331,7 +1355,7 @@ class XLMLMHeadModel(XLMPreTrainedModel):
|
|||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||||
labels=None, output_all_encoded_layers=True, head_mask=None):
|
labels=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||||
@@ -1358,33 +1382,28 @@ class XLMLMHeadModel(XLMPreTrainedModel):
|
|||||||
summary_type: str, "last", "first", "mean", or "attn". The method
|
summary_type: str, "last", "first", "mean", or "attn". The method
|
||||||
to pool the input to get a vector representation.
|
to pool the input to get a vector representation.
|
||||||
"""
|
"""
|
||||||
output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||||
mems, perm_mask, target_mapping, inp_q,
|
mems, perm_mask, target_mapping, inp_q, head_mask)
|
||||||
output_all_encoded_layers, head_mask)
|
|
||||||
|
|
||||||
|
output = transformer_outputs[0]
|
||||||
logits = self.lm_loss(output)
|
logits = self.lm_loss(output)
|
||||||
|
|
||||||
|
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
loss = loss_fct(logits.view(-1, logits.size(-1)),
|
loss = loss_fct(logits.view(-1, logits.size(-1)),
|
||||||
labels.view(-1))
|
labels.view(-1))
|
||||||
return loss, new_mems
|
outputs = [loss] + outputs
|
||||||
|
|
||||||
# if self.output_attentions:
|
outputs = [logits] + outputs
|
||||||
# all_attentions, encoded_layers = encoded_layers
|
|
||||||
# sequence_output = encoded_layers[-1]
|
return outputs
|
||||||
# pooled_output = self.pooler(sequence_output)
|
|
||||||
# if not output_all_encoded_layers:
|
|
||||||
# encoded_layers = encoded_layers[-1]
|
|
||||||
# if self.output_attentions:
|
|
||||||
return logits, new_mems
|
|
||||||
# return all_attentions, encoded_layers, pooled_output
|
|
||||||
|
|
||||||
|
|
||||||
class XLMSequenceSummary(nn.Module):
|
class XLMSequenceSummary(nn.Module):
|
||||||
def __init__(self, config, summary_type="last", use_proj=True,
|
def __init__(self, config, summary_type="last", use_proj=True):
|
||||||
output_attentions=False, keep_multihead_output=False):
|
|
||||||
super(XLMSequenceSummary, self).__init__()
|
super(XLMSequenceSummary, self).__init__()
|
||||||
self.summary_type = summary_type
|
self.summary_type = summary_type
|
||||||
if use_proj:
|
if use_proj:
|
||||||
@@ -1481,26 +1500,23 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
||||||
output_attentions=False, keep_multihead_output=False):
|
output_attentions=False, output_hidden_states=False):
|
||||||
super(XLMForSequenceClassification, self).__init__(config)
|
super(XLMForSequenceClassification, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
self.attn_type = config.attn_type
|
self.output_hidden_states = output_hidden_states
|
||||||
self.same_length = config.same_length
|
|
||||||
self.summary_type = summary_type
|
self.summary_type = summary_type
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
|
|
||||||
self.transformer = XLMModel(config, output_attentions=output_attentions,
|
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
|
|
||||||
self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type,
|
self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
|
||||||
use_proj=use_proj, output_attentions=output_attentions,
|
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
self.logits_proj = nn.Linear(config.d_model, num_labels)
|
self.logits_proj = nn.Linear(config.d_model, num_labels)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||||
labels=None, output_all_encoded_layers=True, head_mask=None):
|
labels=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||||
@@ -1528,13 +1544,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
|||||||
Only used during pretraining for two-stream attention.
|
Only used during pretraining for two-stream attention.
|
||||||
Set to None during finetuning.
|
Set to None during finetuning.
|
||||||
"""
|
"""
|
||||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||||
mems, perm_mask, target_mapping, inp_q,
|
mems, perm_mask, target_mapping, inp_q, head_mask)
|
||||||
output_all_encoded_layers, head_mask)
|
|
||||||
|
|
||||||
|
output = transformer_outputs[0]
|
||||||
output = self.sequence_summary(output)
|
output = self.sequence_summary(output)
|
||||||
logits = self.logits_proj(output)
|
logits = self.logits_proj(output)
|
||||||
|
|
||||||
|
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.num_labels == 1:
|
||||||
# We are doing regression
|
# We are doing regression
|
||||||
@@ -1543,17 +1561,11 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss, new_mems
|
outputs = [loss] + outputs
|
||||||
|
|
||||||
# if self.output_attentions:
|
outputs = [logits] + outputs
|
||||||
# all_attentions, encoded_layers = encoded_layers
|
|
||||||
# sequence_output = encoded_layers[-1]
|
return outputs
|
||||||
# pooled_output = self.pooler(sequence_output)
|
|
||||||
# if not output_all_encoded_layers:
|
|
||||||
# encoded_layers = encoded_layers[-1]
|
|
||||||
# if self.output_attentions:
|
|
||||||
return logits, new_mems
|
|
||||||
# return all_attentions, encoded_layers, pooled_output
|
|
||||||
|
|
||||||
|
|
||||||
class XLMForQuestionAnswering(XLMPreTrainedModel):
|
class XLMForQuestionAnswering(XLMPreTrainedModel):
|
||||||
@@ -1612,27 +1624,30 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
|||||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(XLMForQuestionAnswering, self).__init__(config)
|
super(XLMForQuestionAnswering, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
self.transformer = XLMModel(config, output_attentions=output_attentions,
|
self.output_hidden_states = output_hidden_states
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
|
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||||
start_positions=None, end_positions=None,
|
start_positions=None, end_positions=None, head_mask=None):
|
||||||
output_all_encoded_layers=True, head_mask=None):
|
|
||||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
|
||||||
mems, perm_mask, target_mapping, inp_q,
|
|
||||||
output_all_encoded_layers, head_mask)
|
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||||
|
mems, perm_mask, target_mapping, inp_q, head_mask)
|
||||||
|
|
||||||
|
output = transformer_outputs[0]
|
||||||
logits = self.qa_outputs(output)
|
logits = self.qa_outputs(output)
|
||||||
start_logits, end_logits = logits.split(1, dim=-1)
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
|
||||||
|
|
||||||
if start_positions is not None and end_positions is not None:
|
if start_positions is not None and end_positions is not None:
|
||||||
# If we are on multi-GPU, split add a dimension
|
# If we are on multi-GPU, split add a dimension
|
||||||
if len(start_positions.size()) > 1:
|
if len(start_positions.size()) > 1:
|
||||||
@@ -1648,7 +1663,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
|||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
end_loss = loss_fct(end_logits, end_positions)
|
end_loss = loss_fct(end_logits, end_positions)
|
||||||
total_loss = (start_loss + end_loss) / 2
|
total_loss = (start_loss + end_loss) / 2
|
||||||
return total_loss
|
outputs = [total_loss] + outputs
|
||||||
elif self.output_attentions:
|
|
||||||
return all_attentions, start_logits, end_logits
|
outputs = [start_logits, end_logits] + outputs
|
||||||
return start_logits, end_logits
|
|
||||||
|
return outputs
|
||||||
|
|||||||
@@ -323,16 +323,13 @@ except ImportError:
|
|||||||
return self.weight * x + self.bias
|
return self.weight * x + self.bias
|
||||||
|
|
||||||
class XLNetRelativeAttention(nn.Module):
|
class XLNetRelativeAttention(nn.Module):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(XLNetRelativeAttention, self).__init__()
|
super(XLNetRelativeAttention, self).__init__()
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
if config.d_model % config.n_head != 0:
|
if config.d_model % config.n_head != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The hidden size (%d) is not a multiple of the number of attention "
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
"heads (%d)" % (config.d_model, config.n_head))
|
"heads (%d)" % (config.d_model, config.n_head))
|
||||||
self.output_attentions = output_attentions
|
|
||||||
self.keep_multihead_output = keep_multihead_output
|
|
||||||
self.multihead_output = None
|
|
||||||
|
|
||||||
self.n_head = config.n_head
|
self.n_head = config.n_head
|
||||||
self.d_head = config.d_head
|
self.d_head = config.d_head
|
||||||
@@ -368,7 +365,7 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None):
|
def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None, head_mask=None):
|
||||||
"""Core relative positional attention operations."""
|
"""Core relative positional attention operations."""
|
||||||
|
|
||||||
# content based attention score
|
# content based attention score
|
||||||
@@ -395,9 +392,16 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
attn_prob = F.softmax(attn_score, dim=1)
|
attn_prob = F.softmax(attn_score, dim=1)
|
||||||
attn_prob = self.dropout(attn_prob)
|
attn_prob = self.dropout(attn_prob)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if head_mask is not None:
|
||||||
|
attn_prob = attn_prob * head_mask
|
||||||
|
|
||||||
# attention output
|
# attention output
|
||||||
attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
|
attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
return attn_vec, attn_prob
|
||||||
|
|
||||||
return attn_vec
|
return attn_vec
|
||||||
|
|
||||||
def post_attention(self, h, attn_vec, residual=True):
|
def post_attention(self, h, attn_vec, residual=True):
|
||||||
@@ -439,7 +443,10 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
|
|
||||||
# core attention ops
|
# core attention ops
|
||||||
attn_vec_h = self.rel_attn_core(
|
attn_vec_h = self.rel_attn_core(
|
||||||
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)
|
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask)
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
attn_vec_h, attn_prob_h = attn_vec_h
|
||||||
|
|
||||||
# post processing
|
# post processing
|
||||||
output_h = self.post_attention(h, attn_vec_h)
|
output_h = self.post_attention(h, attn_vec_h)
|
||||||
@@ -452,14 +459,25 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
if target_mapping is not None:
|
if target_mapping is not None:
|
||||||
q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
|
q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
|
||||||
attn_vec_g = self.rel_attn_core(
|
attn_vec_g = self.rel_attn_core(
|
||||||
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g)
|
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask)
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
attn_vec_g, attn_prob_g = attn_vec_g
|
||||||
|
|
||||||
attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
|
attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
|
||||||
else:
|
else:
|
||||||
attn_vec_g = self.rel_attn_core(
|
attn_vec_g = self.rel_attn_core(
|
||||||
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g)
|
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask)
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
attn_vec_g, attn_prob_g = attn_vec_g
|
||||||
|
|
||||||
# post processing
|
# post processing
|
||||||
output_g = self.post_attention(g, attn_vec_g)
|
output_g = self.post_attention(g, attn_vec_g)
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
attn_prob = attn_prob_h, attn_prob_g
|
||||||
|
|
||||||
else:
|
else:
|
||||||
###### Multi-head attention with relative positional encoding
|
###### Multi-head attention with relative positional encoding
|
||||||
if mems is not None and mems.dim() > 1:
|
if mems is not None and mems.dim() > 1:
|
||||||
@@ -477,30 +495,18 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
|
|
||||||
# core attention ops
|
# core attention ops
|
||||||
attn_vec = self.rel_attn_core(
|
attn_vec = self.rel_attn_core(
|
||||||
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)
|
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask)
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
attn_vec, attn_prob = attn_vec
|
||||||
|
|
||||||
# post processing
|
# post processing
|
||||||
output_h = self.post_attention(h, attn_vec)
|
output_h = self.post_attention(h, attn_vec)
|
||||||
output_g = None
|
output_g = None
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
return output_h, output_g, attn_prob
|
||||||
|
|
||||||
# Mask heads if we want to
|
|
||||||
# if head_mask is not None:
|
|
||||||
# attention_probs = attention_probs * head_mask
|
|
||||||
|
|
||||||
# context_layer = torch.matmul(attention_probs, value_layer)
|
|
||||||
# if self.keep_multihead_output:
|
|
||||||
# self.multihead_output = context_layer
|
|
||||||
# self.multihead_output.retain_grad()
|
|
||||||
|
|
||||||
# context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
|
||||||
# new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
|
||||||
# context_layer = context_layer.view(*new_context_layer_shape)
|
|
||||||
|
|
||||||
# if self.output_attentions:
|
|
||||||
# attentions, self_output = self_output
|
|
||||||
# if self.output_attentions:
|
|
||||||
# return attentions, attention_output
|
|
||||||
return output_h, output_g
|
return output_h, output_g
|
||||||
|
|
||||||
class XLNetFeedForward(nn.Module):
|
class XLNetFeedForward(nn.Module):
|
||||||
@@ -510,7 +516,8 @@ class XLNetFeedForward(nn.Module):
|
|||||||
self.layer_1 = nn.Linear(config.d_model, config.d_inner)
|
self.layer_1 = nn.Linear(config.d_model, config.d_inner)
|
||||||
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
|
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
if isinstance(config.ff_activation, str) or (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
|
if isinstance(config.ff_activation, str) or \
|
||||||
|
(sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
|
||||||
self.activation_function = ACT2FN[config.ff_activation]
|
self.activation_function = ACT2FN[config.ff_activation]
|
||||||
else:
|
else:
|
||||||
self.activation_function = config.ff_activation
|
self.activation_function = config.ff_activation
|
||||||
@@ -526,29 +533,27 @@ class XLNetFeedForward(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
class XLNetLayer(nn.Module):
|
class XLNetLayer(nn.Module):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, ):
|
||||||
super(XLNetLayer, self).__init__()
|
super(XLNetLayer, self).__init__()
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions,
|
self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions)
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
self.ff = XLNetFeedForward(config)
|
self.ff = XLNetFeedForward(config)
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
def forward(self, output_h, output_g,
|
def forward(self, output_h, output_g,
|
||||||
attn_mask_h, attn_mask_g,
|
attn_mask_h, attn_mask_g,
|
||||||
r, seg_mat,
|
r, seg_mat, mems=None, target_mapping=None, head_mask=None):
|
||||||
mems=None, target_mapping=None, head_mask=None):
|
outputs = self.rel_attn(output_h, output_g, attn_mask_h, attn_mask_g,
|
||||||
output_h, output_g = self.rel_attn(output_h, output_g,
|
r, seg_mat, mems=mems, target_mapping=target_mapping,
|
||||||
attn_mask_h, attn_mask_g,
|
head_mask=head_mask)
|
||||||
r, seg_mat,
|
output_h, output_g = outputs[:2]
|
||||||
mems=mems, target_mapping=target_mapping, head_mask=head_mask)
|
|
||||||
if output_g is not None:
|
if output_g is not None:
|
||||||
output_g = self.ff(output_g)
|
output_g = self.ff(output_g)
|
||||||
output_h = self.ff(output_h)
|
output_h = self.ff(output_h)
|
||||||
|
|
||||||
# if self.output_attentions:
|
outputs = [output_h, output_g] + outputs[2:] # Add again attentions if there are there
|
||||||
# return attentions, layer_output
|
return outputs
|
||||||
return output_h, output_g
|
|
||||||
|
|
||||||
|
|
||||||
class XLNetPreTrainedModel(PreTrainedModel):
|
class XLNetPreTrainedModel(PreTrainedModel):
|
||||||
@@ -584,9 +589,11 @@ class XLNetPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class XLNetModel(XLNetPreTrainedModel):
|
class XLNetModel(XLNetPreTrainedModel):
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(XLNetModel, self).__init__(config)
|
super(XLNetModel, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
|
|
||||||
self.mem_len = config.mem_len
|
self.mem_len = config.mem_len
|
||||||
self.reuse_len = config.reuse_len
|
self.reuse_len = config.reuse_len
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@@ -597,8 +604,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
self.word_embedding = nn.Embedding(config.n_token, config.d_model)
|
self.word_embedding = nn.Embedding(config.n_token, config.d_model)
|
||||||
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
|
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
|
||||||
layer = XLNetLayer(config, output_attentions=output_attentions,
|
layer = XLNetLayer(config, output_attentions=output_attentions)
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
|
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
@@ -851,28 +857,39 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
if mems is None:
|
if mems is None:
|
||||||
mems = [None] * len(self.layer)
|
mems = [None] * len(self.layer)
|
||||||
|
|
||||||
|
attentions = []
|
||||||
hidden_states = []
|
hidden_states = []
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
# cache new mems
|
# cache new mems
|
||||||
new_mems.append(self.cache_mem(output_h, mems[i]))
|
new_mems.append(self.cache_mem(output_h, mems[i]))
|
||||||
|
if self.output_hidden_states:
|
||||||
|
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||||
|
|
||||||
|
outputs = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
|
||||||
|
r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
|
||||||
|
head_mask=head_mask)
|
||||||
|
output_h, output_g = outputs[:2]
|
||||||
|
if self.output_attentions:
|
||||||
|
attentions.append(outputs[2:])
|
||||||
|
|
||||||
|
# Add last hidden state
|
||||||
|
if self.output_hidden_states:
|
||||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||||
|
|
||||||
output_h, output_g = layer_module(output_h, output_g,
|
|
||||||
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
|
|
||||||
r=pos_emb, seg_mat=seg_mat,
|
|
||||||
mems=mems[i], target_mapping=target_mapping,
|
|
||||||
head_mask=head_mask)
|
|
||||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
|
||||||
output = self.dropout(output_g if output_g is not None else output_h)
|
output = self.dropout(output_g if output_g is not None else output_h)
|
||||||
|
|
||||||
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||||
output = output.permute(1, 0, 2).contiguous()
|
outputs = [output.permute(1, 0, 2).contiguous(), new_mems]
|
||||||
if output_g is not None:
|
if self.output_hidden_states:
|
||||||
hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
|
if output_g is not None:
|
||||||
else:
|
hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
|
||||||
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
|
else:
|
||||||
|
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
|
||||||
|
outputs.append(hidden_states)
|
||||||
|
if self.output_attentions:
|
||||||
|
outputs.append(attentions)
|
||||||
|
|
||||||
return output, hidden_states, new_mems
|
return outputs # outputs, new_mems, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
class XLNetLMHeadModel(XLNetPreTrainedModel):
|
class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||||
@@ -936,14 +953,16 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(XLNetLMHeadModel, self).__init__(config)
|
super(XLNetLMHeadModel, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
|
|
||||||
self.attn_type = config.attn_type
|
self.attn_type = config.attn_type
|
||||||
self.same_length = config.same_length
|
self.same_length = config.same_length
|
||||||
|
|
||||||
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
||||||
keep_multihead_output=keep_multihead_output)
|
output_hidden_states=output_hidden_states)
|
||||||
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
|
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
|
||||||
|
|
||||||
# Tie weights
|
# Tie weights
|
||||||
@@ -989,27 +1008,24 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
summary_type: str, "last", "first", "mean", or "attn". The method
|
summary_type: str, "last", "first", "mean", or "attn". The method
|
||||||
to pool the input to get a vector representation.
|
to pool the input to get a vector representation.
|
||||||
"""
|
"""
|
||||||
output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||||
mems, perm_mask, target_mapping, inp_q, head_mask)
|
mems, perm_mask, target_mapping, inp_q, head_mask)
|
||||||
|
|
||||||
logits = self.lm_loss(output)
|
logits = self.lm_loss(transformer_outputs[0])
|
||||||
|
|
||||||
|
outputs = [logits] + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
loss = loss_fct(logits.view(-1, logits.size(-1)),
|
loss = loss_fct(logits.view(-1, logits.size(-1)),
|
||||||
labels.view(-1))
|
labels.view(-1))
|
||||||
return loss, new_mems
|
outputs = [loss] + outputs
|
||||||
|
|
||||||
# if self.output_attentions:
|
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||||
# all_attentions, encoded_layers = encoded_layers
|
|
||||||
# if self.output_attentions:
|
|
||||||
return logits, new_mems
|
|
||||||
# return all_attentions, encoded_layers, pooled_output
|
|
||||||
|
|
||||||
class XLNetSequenceSummary(nn.Module):
|
class XLNetSequenceSummary(nn.Module):
|
||||||
def __init__(self, config, summary_type="last", use_proj=True,
|
def __init__(self, config, summary_type="last", use_proj=True):
|
||||||
output_attentions=False, keep_multihead_output=False):
|
|
||||||
super(XLNetSequenceSummary, self).__init__()
|
super(XLNetSequenceSummary, self).__init__()
|
||||||
self.summary_type = summary_type
|
self.summary_type = summary_type
|
||||||
if use_proj:
|
if use_proj:
|
||||||
@@ -1106,20 +1122,20 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
|
||||||
output_attentions=False, keep_multihead_output=False):
|
output_attentions=False, output_hidden_states=False):
|
||||||
super(XLNetForSequenceClassification, self).__init__(config)
|
super(XLNetForSequenceClassification, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
|
|
||||||
self.attn_type = config.attn_type
|
self.attn_type = config.attn_type
|
||||||
self.same_length = config.same_length
|
self.same_length = config.same_length
|
||||||
self.summary_type = summary_type
|
self.summary_type = summary_type
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
|
|
||||||
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
||||||
keep_multihead_output=keep_multihead_output)
|
output_hidden_states=output_hidden_states)
|
||||||
|
|
||||||
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
|
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
|
||||||
use_proj=use_proj, output_attentions=output_attentions,
|
|
||||||
keep_multihead_output=keep_multihead_output)
|
|
||||||
self.logits_proj = nn.Linear(config.d_model, num_labels)
|
self.logits_proj = nn.Linear(config.d_model, num_labels)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
@@ -1153,12 +1169,15 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
Only used during pretraining for two-stream attention.
|
Only used during pretraining for two-stream attention.
|
||||||
Set to None during finetuning.
|
Set to None during finetuning.
|
||||||
"""
|
"""
|
||||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||||
mems, perm_mask, target_mapping, inp_q, head_mask)
|
mems, perm_mask, target_mapping, inp_q, head_mask)
|
||||||
|
output = transformer_outputs[0]
|
||||||
|
|
||||||
output = self.sequence_summary(output)
|
output = self.sequence_summary(output)
|
||||||
logits = self.logits_proj(output)
|
logits = self.logits_proj(output)
|
||||||
|
|
||||||
|
outputs = [logits] + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.num_labels == 1:
|
if self.num_labels == 1:
|
||||||
# We are doing regression
|
# We are doing regression
|
||||||
@@ -1167,13 +1186,10 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss, new_mems
|
outputs = [loss] + outputs
|
||||||
|
|
||||||
|
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||||
|
|
||||||
# if self.output_attentions:
|
|
||||||
# all_attentions, encoded_layers = encoded_layers
|
|
||||||
# if self.output_attentions:
|
|
||||||
return logits, new_mems
|
|
||||||
# return all_attentions, encoded_layers, pooled_output
|
|
||||||
|
|
||||||
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||||
"""XLNet model for Question Answering (span extraction).
|
"""XLNet model for Question Answering (span extraction).
|
||||||
@@ -1231,25 +1247,30 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
|
def __init__(self, config, output_attentions=False, output_hidden_states=False):
|
||||||
super(XLNetForQuestionAnswering, self).__init__(config)
|
super(XLNetForQuestionAnswering, self).__init__(config)
|
||||||
self.output_attentions = output_attentions
|
self.output_attentions = output_attentions
|
||||||
|
self.output_hidden_states = output_hidden_states
|
||||||
|
|
||||||
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
self.transformer = XLNetModel(config, output_attentions=output_attentions,
|
||||||
keep_multihead_output=keep_multihead_output)
|
output_hidden_states=output_hidden_states)
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||||
start_positions=None, end_positions=None, head_mask=None):
|
start_positions=None, end_positions=None, head_mask=None):
|
||||||
output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
||||||
mems, perm_mask, target_mapping, inp_q, head_mask)
|
mems, perm_mask, target_mapping, inp_q, head_mask)
|
||||||
|
|
||||||
logits = self.qa_outputs(output)
|
logits = self.qa_outputs(transformer_outputs[0])
|
||||||
|
|
||||||
start_logits, end_logits = logits.split(1, dim=-1)
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
outputs = [start_logits, end_logits] + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
||||||
|
|
||||||
if start_positions is not None and end_positions is not None:
|
if start_positions is not None and end_positions is not None:
|
||||||
# If we are on multi-GPU, split add a dimension
|
# If we are on multi-GPU, split add a dimension
|
||||||
if len(start_positions.size()) > 1:
|
if len(start_positions.size()) > 1:
|
||||||
@@ -1265,7 +1286,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
start_loss = loss_fct(start_logits, start_positions)
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
end_loss = loss_fct(end_logits, end_positions)
|
end_loss = loss_fct(end_logits, end_positions)
|
||||||
total_loss = (start_loss + end_loss) / 2
|
total_loss = (start_loss + end_loss) / 2
|
||||||
return total_loss
|
outputs = [total_loss] + outputs
|
||||||
elif self.output_attentions:
|
|
||||||
return all_attentions, start_logits, end_logits
|
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||||
return start_logits, end_logits
|
|
||||||
|
|||||||
Reference in New Issue
Block a user