adding tests

This commit is contained in:
thomwolf
2019-08-05 18:14:07 +02:00
parent b90e29d52c
commit ed4e542260
4 changed files with 83 additions and 3 deletions

View File

@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel):
def __init__(self, base_model):
super(AutoModelWithLMHead, self).__init__(base_model)
config = base_model.config
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.apply(self.init_weights)
@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel):
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = {
'num_labels': 2,
'summary_type': 'first',
'summary_use_proj': True,
'summary_activation': None,
'summary_proj_to_labels': True,
'summary_first_dropout': 0.1
}
class AutoModelForSequenceClassification(DerivedAutoModel):
r"""
:class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel):
def __init__(self, base_model):
super(AutoModelForSequenceClassification, self).__init__(base_model)
self.num_labels = base_model.config.num_labels
self.sequence_summary = SequenceSummary(base_model.config)
# Complete configuration with defaults if necessary
config = base_model.config
for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items():
if not hasattr(config, key):
setattr(config, key, value)
# Update base model and derived model config
self.transformer.config = config
self.config = config
self.num_labels = config.num_labels
self.sequence_summary = SequenceSummary(config)
self.apply(self.init_weights)