New run glue script (#7917)
* Start simplification * More progress * Finished script * Address comments and update tests instructions * Wrong test * Accept files as inputs and fix test * Update src/transformers/trainer_utils.py Co-authored-by: Julien Chaumond <chaumond@gmail.com> * Fix labels and add combined score * Add special labels * Update TPU command * Revert to old label strategy * Use model labels * Fix for STT-B * Styling * Apply suggestions from code review Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Code styling * Fix review comments Co-authored-by: Julien Chaumond <chaumond@gmail.com> Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -854,8 +854,6 @@ class Trainer:
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
|
||||
if self.control.should_save:
|
||||
self._save_checkpoint(model, trial, metrics=metrics)
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
@@ -1173,7 +1171,7 @@ class Trainer:
|
||||
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(output_dir)
|
||||
if self.tokenizer is not None:
|
||||
if self.tokenizer is not None and self.is_world_process_zero():
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None):
|
||||
@@ -1188,7 +1186,7 @@ class Trainer:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(output_dir)
|
||||
if self.tokenizer is not None:
|
||||
if self.tokenizer is not None and self.is_world_process_zero():
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
@@ -1272,6 +1270,7 @@ class Trainer:
|
||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||
xm.master_print(met.metrics_report())
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
||||
return output.metrics
|
||||
|
||||
def predict(self, test_dataset: Dataset) -> PredictionOutput:
|
||||
|
||||
Reference in New Issue
Block a user