Switch from return_tuple to return_dict (#6138)
* Switch from return_tuple to return_dict
* Fix test
* [WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleC… (#5614)
* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests
* AutoModels
Tiny tweaks
* Style
* Final changes before merge
* Re-order for simpler review
* Final fixes
* Addressing @sgugger's comments
* Test MultipleChoice
* Rework TF trainer (#6038)
* Fully rework training/prediction loops
* fix method name
* Fix variable name
* Fix property name
* Fix scope
* Fix method name
* Fix tuple index
* Fix tuple index
* Fix indentation
* Fix variable name
* fix eval before log
* Add drop remainder for test dataset
* Fix step number + fix logging datetime
* fix eval loss value
* use global step instead of step + fix logging at step 0
* Fix logging datetime
* Fix global_step usage
* Fix breaking loop + logging datetime
* Fix step in prediction loop
* Fix step breaking
* Fix train/test loops
* Force TF at least 2.2 for the trainer
* Use assert_cardinality to facilitate the dataset size computation
* Log steps per epoch
* Make tfds compliant with TPU
* Make tfds compliant with TPU
* Use TF dataset enumerate instead of the Python one
* revert previous commit
* Fix data_dir
* Apply style
* rebase on master
* Address Sylvain's comments
* Address Sylvain's and Lysandre comments
* Trigger CI
* Remove unused import
* Switch from return_tuple to return_dict
* Fix test
* Add recent model
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
This commit is contained in:
@@ -199,9 +199,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
|
||||
)
|
||||
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
inputs["return_tuple"] = True
|
||||
|
||||
outputs = model(**inputs)
|
||||
# model outputs are always tuple in transformers (see doc)
|
||||
loss = outputs[0]
|
||||
@@ -316,8 +313,6 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
inputs.update(
|
||||
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
|
||||
)
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
inputs["return_tuple"] = True
|
||||
outputs = model(**inputs)
|
||||
|
||||
for i, feature_index in enumerate(feature_indices):
|
||||
|
||||
@@ -144,7 +144,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
def test_loss_fn(self):
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)
|
||||
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
|
||||
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
||||
|
||||
Reference in New Issue
Block a user