run_squad WIP

This commit is contained in:
thomwolf
2018-11-02 03:56:14 +01:00
parent c0065af6cb
commit e61db0d1c0
3 changed files with 75 additions and 40 deletions

View File

@@ -441,8 +441,8 @@ class BertForSequenceClassification(nn.Module):
class BertForQuestionAnswering(nn.Module):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with linear layers on top of
the sequence output.
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Example usage:
```python
@@ -455,7 +455,7 @@ class BertForQuestionAnswering(nn.Module):
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = BertForQuestionAnswering(config)
logits = model(input_ids, token_type_ids, input_mask)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):