Switch from return_tuple to return_dict (#6138)
* Switch from return_tuple to return_dict
* Fix test
* [WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleC… (#5614)
* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests
* AutoModels
* Tiny tweaks
* Style
* Final changes before merge
* Re-order for simpler review
* Final fixes
* Addressing @sgugger's comments
* Test MultipleChoice
* Rework TF trainer (#6038)
* Fully rework training/prediction loops
* fix method name
* Fix variable name
* Fix property name
* Fix scope
* Fix method name
* Fix tuple index
* Fix tuple index
* Fix indentation
* Fix variable name
* fix eval before log
* Add drop remainder for test dataset
* Fix step number + fix logging datetime
* fix eval loss value
* use global step instead of step + fix logging at step 0
* Fix logging datetime
* Fix global_step usage
* Fix breaking loop + logging datetime
* Fix step in prediction loop
* Fix step breaking
* Fix train/test loops
* Force TF at least 2.2 for the trainer
* Use assert_cardinality to facilitate the dataset size computation
* Log steps per epoch
* Make tfds compliant with TPU
* Make tfds compliant with TPU
* Use TF dataset enumerate instead of the Python one
* revert previous commit
* Fix data_dir
* Apply style
* rebase on master
* Address Sylvain's comments
* Address Sylvain's and Lysandre comments
* Trigger CI
* Remove unused import
* Switch from return_tuple to return_dict
* Fix test
* Add recent model
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
@@ -1167,7 +1167,7 @@ class SQuADHead(nn.Module):
         cls_index: Optional[torch.LongTensor] = None,
         is_impossible: Optional[torch.LongTensor] = None,
         p_mask: Optional[torch.FloatTensor] = None,
-        return_tuple: bool = False,
+        return_dict: bool = False,
     ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
         """
         Args:
@@ -1184,8 +1184,8 @@ class SQuADHead(nn.Module):
             p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
                 Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
                 1.0 means token should be masked.
-            return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                Whether or not to return a plain tuple instead of a :class:`~transformers.file_utils.ModelOuput`.
+            return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to return a :class:`~transformers.file_utils.ModelOuput` instead of a plain tuple.
 
         Returns:
         """
@@ -1214,7 +1214,7 @@ class SQuADHead(nn.Module):
                 # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                 total_loss += cls_loss * 0.5
 
-            return (total_loss,) if return_tuple else SquadHeadOutput(loss=total_loss)
+            return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
 
         else:
             # during inference, compute the end logits based on beam search
@@ -1244,7 +1244,7 @@ class SQuADHead(nn.Module):
             start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
             cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
 
-            if return_tuple:
+            if not return_dict:
                 return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
             else:
                 return SquadHeadOutput(
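
Note on the behavior change in the hunks above: both flags default to False, so the default return type flips. With `return_tuple=False` the head returned a `SquadHeadOutput` unless asked otherwise; with `return_dict=False` it returns a plain tuple unless asked otherwise. The following is a minimal, torch-free sketch of that pattern, not the library code. `HeadOutput`, `head_forward` and `extract_loss` are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass
class HeadOutput:
    """Hypothetical stand-in for SquadHeadOutput (not the library class)."""

    loss: Optional[float] = None


def head_forward(loss: float, return_dict: bool = False) -> Union[HeadOutput, Tuple[float]]:
    # Mirrors the shape of the patched return statement:
    # a structured output only when the caller opts in, otherwise a plain tuple.
    return HeadOutput(loss=loss) if return_dict else (loss,)


def extract_loss(output: Union[HeadOutput, Tuple[float]]) -> float:
    # Hypothetical caller-side helper that accepts either return form.
    return output.loss if isinstance(output, HeadOutput) else output[0]


default_out = head_forward(0.25)                 # plain tuple (the new default)
dict_out = head_forward(0.25, return_dict=True)  # structured output (opt-in)
```

Under these assumptions, a caller that previously relied on attribute access by default would need to pass `return_dict=True`, while a caller that passed `return_tuple=True` can simply drop the argument.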