Add timing inside Trainer (#9196)
* Add timing inside Trainer * Fix tests * Add n_objs for train * Sort logs
This commit is contained in:
@@ -12,10 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -411,7 +410,16 @@ class TrainingArguments:
|
||||
self.run_name = self.output_dir
|
||||
|
||||
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
||||
raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
|
||||
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
||||
|
||||
def __repr__(self):
|
||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
|
||||
self_as_dict = asdict(self)
|
||||
del self_as_dict["per_gpu_train_batch_size"]
|
||||
del self_as_dict["per_gpu_eval_batch_size"]
|
||||
attrs_as_str = [f"{k}={v}" for k, v in self_as_dict.items()]
|
||||
return f"{self.__class__.__name__}({', '.join(attrs_as_str)})"
|
||||
|
||||
@property
|
||||
def train_batch_size(self) -> int:
|
||||
@@ -523,7 +531,7 @@ class TrainingArguments:
|
||||
"""
|
||||
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
d = asdict(self)
|
||||
for k, v in d.items():
|
||||
if isinstance(v, Enum):
|
||||
d[k] = v.value
|
||||
|
||||
Reference in New Issue
Block a user