Switch from return_tuple to return_dict (#6138)
* Switch from return_tuple to return_dict
* Fix test
* [WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleC… (#5614)
* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests
* AutoModels
Tiny tweaks
* Style
* Final changes before merge
* Re-order for simpler review
* Final fixes
* Addressing @sgugger's comments
* Test MultipleChoice
* Rework TF trainer (#6038)
* Fully rework training/prediction loops
* fix method name
* Fix variable name
* Fix property name
* Fix scope
* Fix method name
* Fix tuple index
* Fix tuple index
* Fix indentation
* Fix variable name
* fix eval before log
* Add drop remainder for test dataset
* Fix step number + fix logging datetime
* fix eval loss value
* use global step instead of step + fix logging at step 0
* Fix logging datetime
* Fix global_step usage
* Fix breaking loop + logging datetime
* Fix step in prediction loop
* Fix step breaking
* Fix train/test loops
* Force TF at least 2.2 for the trainer
* Use assert_cardinality to facilitate the dataset size computation
* Log steps per epoch
* Make tfds compliant with TPU
* Make tfds compliant with TPU
* Use TF dataset enumerate instead of the Python one
* revert previous commit
* Fix data_dir
* Apply style
* rebase on master
* Address Sylvain's comments
* Address Sylvain's and Lysandre comments
* Trigger CI
* Remove unused import
* Switch from return_tuple to return_dict
* Fix test
* Add recent model
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
This commit is contained in:
@@ -550,7 +550,7 @@ class MobileBertEncoder(nn.Module):
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_tuple=False,
|
||||
return_dict=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
@@ -575,7 +575,7 @@ class MobileBertEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
@@ -708,9 +708,9 @@ class MobileBertForPretrainingOutput(ModelOutput):
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor]
|
||||
prediction_logits: torch.FloatTensor
|
||||
seq_relationship_logits: torch.FloatTensor
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
prediction_logits: torch.FloatTensor = None
|
||||
seq_relationship_logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
@@ -773,8 +773,9 @@ MOBILEBERT_INPUTS_DOCSTRING = r"""
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
||||
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
|
||||
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
|
||||
plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@@ -831,13 +832,13 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
encoder_attention_mask=None,
|
||||
output_hidden_states=None,
|
||||
output_attentions=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
@@ -890,12 +891,12 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
@@ -958,7 +959,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
next_sentence_label=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
||||
@@ -979,7 +980,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
|
||||
>>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased")
|
||||
>>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased", return_dict=True)
|
||||
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
@@ -988,7 +989,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
||||
|
||||
"""
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
@@ -999,7 +1000,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
@@ -1011,7 +1012,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
@@ -1079,7 +1080,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -1097,7 +1098,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
@@ -1110,7 +1111,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1121,7 +1122,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
@@ -1169,7 +1170,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
next_sentence_label=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1186,7 +1187,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
|
||||
>>> model = MobileBertForNextSentencePrediction.from_pretrained('google/mobilebert-uncased')
|
||||
>>> model = MobileBertForNextSentencePrediction.from_pretrained('google/mobilebert-uncased', return_dict=True)
|
||||
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||
@@ -1196,7 +1197,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
>>> loss = outputs.loss
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
@@ -1207,7 +1208,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -1218,7 +1219,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (seq_relationship_score,) + outputs[2:]
|
||||
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
||||
|
||||
@@ -1263,7 +1264,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1272,7 +1273,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
@@ -1283,7 +1284,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
@@ -1299,7 +1300,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@@ -1342,7 +1343,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
end_positions=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1354,7 +1355,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
@@ -1365,7 +1366,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1392,7 +1393,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
@@ -1438,7 +1439,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1446,7 +1447,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above)
|
||||
"""
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||
@@ -1468,7 +1469,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -1482,7 +1483,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@@ -1525,14 +1526,14 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the token classification loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
"""
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
@@ -1543,7 +1544,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1565,7 +1566,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user