Switch from return_tuple to return_dict (#6138)

* Switch from return_tuple to return_dict

* Fix test

* [WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleC… (#5614)

* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests

* AutoModels


Tiny tweaks

* Style

* Final changes before merge

* Re-order for simpler review

* Final fixes

* Addressing @sgugger's comments

* Test MultipleChoice

* Rework TF trainer (#6038)

* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre comments

* Trigger CI

* Remove unused import

* Switch from return_tuple to return_dict

* Fix test

* Add recent model

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-07-30 09:17:00 -04:00
committed by GitHub
parent 562b6369c4
commit 91cb95461e
35 changed files with 678 additions and 636 deletions

View File

@@ -1167,7 +1167,7 @@ class SQuADHead(nn.Module):
cls_index: Optional[torch.LongTensor] = None,
is_impossible: Optional[torch.LongTensor] = None,
p_mask: Optional[torch.FloatTensor] = None,
return_tuple: bool = False,
return_dict: bool = False,
) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
"""
Args:
@@ -1184,8 +1184,8 @@ class SQuADHead(nn.Module):
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return a plain tuple instead of a :class:`~transformers.file_utils.ModelOuput`.
return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOuput` instead of a plain tuple.
Returns:
"""
@@ -1214,7 +1214,7 @@ class SQuADHead(nn.Module):
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
return (total_loss,) if return_tuple else SquadHeadOutput(loss=total_loss)
return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
else:
# during inference, compute the end logits based on beam search
@@ -1244,7 +1244,7 @@ class SQuADHead(nn.Module):
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
if return_tuple:
if not return_dict:
return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
else:
return SquadHeadOutput(