Add predict step accumulation (#7767)

* Add eval_accumulation_step and clean distributed eval

* Add TPU test

* Add TPU stuff

* Fix arg name

* Fix Seq2SeqTrainer

* Fix total_size

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Doc and add test to TPU

* Add unit test

* Adapt name

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-10-14 11:41:45 -04:00
committed by GitHub
parent 8feb0cc967
commit a1d1b332d0
10 changed files with 413 additions and 47 deletions

View File

@@ -174,7 +174,7 @@ class Seq2SeqTrainer(Trainer):
# Call forward again to get loss # TODO: avoidable?
outputs = model(**inputs, use_cache=False)
loss = self._compute_loss(outputs[1], labels_out)
loss = loss.mean().item()
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)

View File

@@ -81,3 +81,14 @@ class TorchXLAExamplesTests(unittest.TestCase):
# Assert that the script takes less than 300 seconds to make sure it doesn't hang.
self.assertLess(end - start, 300)
def test_trainer_tpu(self):
    """Run the TPU trainer test script through ``xla_spawn.main()``.

    ``sys.argv`` is temporarily replaced with a launcher-style command
    line (``--num_cores=8`` followed by the test script path) for the
    duration of the call.
    """
    import xla_spawn

    tpu_test_script = "transformers/tests/test_trainer_tpu.py"
    # The first entry sits in the argv[0] (program name) position; the
    # script path is given again after the flag as a real argument.
    # NOTE(review): presumably xla_spawn's parser skips argv[0] -- confirm.
    spawn_argv = [tpu_test_script, "--num_cores=8", tpu_test_script]
    with patch.object(sys, "argv", spawn_argv):
        xla_spawn.main()