run_squad WIP
This commit is contained in:
@@ -441,8 +441,8 @@ class BertForSequenceClassification(nn.Module):
|
||||
|
||||
class BertForQuestionAnswering(nn.Module):
|
||||
"""BERT model for Question Answering (span extraction).
|
||||
This module is composed of the BERT model with linear layers on top of
|
||||
the sequence output.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the sequence output that computes start_logits and end_logits
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@@ -455,7 +455,7 @@ class BertForQuestionAnswering(nn.Module):
|
||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||
|
||||
model = BertForQuestionAnswering(config)
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
|
||||
Reference in New Issue
Block a user