New run glue script (#7917)

* Start simplification

* More progress

* Finished script

* Address comments and update tests instructions

* Wrong test

* Accept files as inputs and fix test

* Update src/transformers/trainer_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Fix labels and add combined score

* Add special labels

* Update TPU command

* Revert to old label strategy

* Use model labels

* Fix for STT-B

* Styling

* Apply suggestions from code review

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Code styling

* Fix review comments

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2020-10-22 11:42:22 -04:00
committed by GitHub
parent 18ce6b8ff3
commit 2e5052d4f1
8 changed files with 331 additions and 170 deletions

View File

@@ -854,8 +854,6 @@ class Trainer:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
@@ -1173,7 +1171,7 @@ class Trainer:
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None):
@@ -1188,7 +1186,7 @@ class Trainer:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
@@ -1272,6 +1270,7 @@ class Trainer:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics
def predict(self, test_dataset: Dataset) -> PredictionOutput: