Add timing inside Trainer (#9196)

* Add timing inside Trainer

* Fix tests

* Add n_objs for train

* Sort logs
This commit is contained in:
Sylvain Gugger
2020-12-18 15:10:39 -05:00
committed by GitHub
parent 9a25c5bd3a
commit 1198ba8fba
6 changed files with 76 additions and 49 deletions

View File

@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import os
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
@@ -411,7 +410,16 @@ class TrainingArguments:
self.run_name = self.output_dir
if is_torch_available() and self.device.type != "cuda" and self.fp16:
raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
def __repr__(self):
    """
    Render the instance as ``ClassName(field=value, ...)`` without the deprecated per-GPU batch size arguments.

    This override of the default repr should be dropped once those deprecated arguments are removed from
    TrainingArguments. (TODO: v5)
    """
    deprecated_fields = ("per_gpu_train_batch_size", "per_gpu_eval_batch_size")
    visible_fields = asdict(self)
    # The deprecated fields are expected to exist; a missing one raises KeyError.
    for field_name in deprecated_fields:
        del visible_fields[field_name]
    rendered = ", ".join(f"{name}={value}" for name, value in visible_fields.items())
    return f"{self.__class__.__name__}({rendered})"
@property
def train_batch_size(self) -> int:
@@ -523,7 +531,7 @@ class TrainingArguments:
"""
Serializes this instance while replacing `Enum` members by their values (for JSON serialization support).
"""
d = dataclasses.asdict(self)
d = asdict(self)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value