adding tests
This commit is contained in:
@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel):
|
||||
|
||||
def __init__(self, base_model):
|
||||
super(AutoModelWithLMHead, self).__init__(base_model)
|
||||
config = base_model.config
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel):
|
||||
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
|
||||
|
||||
|
||||
AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = {
|
||||
'num_labels': 2,
|
||||
'summary_type': 'first',
|
||||
'summary_use_proj': True,
|
||||
'summary_activation': None,
|
||||
'summary_proj_to_labels': True,
|
||||
'summary_first_dropout': 0.1
|
||||
}
|
||||
|
||||
|
||||
|
||||
class AutoModelForSequenceClassification(DerivedAutoModel):
|
||||
r"""
|
||||
:class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
|
||||
@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel):
|
||||
|
||||
def __init__(self, base_model):
|
||||
super(AutoModelForSequenceClassification, self).__init__(base_model)
|
||||
self.num_labels = base_model.config.num_labels
|
||||
self.sequence_summary = SequenceSummary(base_model.config)
|
||||
# Complete configuration with defaults if necessary
|
||||
config = base_model.config
|
||||
for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items():
|
||||
if not hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
# Update base model and derived model config
|
||||
self.transformer.config = config
|
||||
self.config = config
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.sequence_summary = SequenceSummary(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user