Trainer multi label (#7191)

* Trainer accep multiple labels

* Missing import

* Fix dosctrings
This commit is contained in:
Sylvain Gugger
2020-09-17 08:15:37 -04:00
committed by GitHub
parent 709745927b
commit 492bb6aa48
4 changed files with 110 additions and 29 deletions

View File

@@ -2,7 +2,7 @@ import dataclasses
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .utils import logging
@@ -128,6 +128,12 @@ class TrainingArguments:
forward method.
(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
label_names (:obj:`List[str]`, `optional`):
The list of keys in your dictionary of inputs that correspond to the labels.
Will eventually default to :obj:`["labels"]` except if the model used is one of the
:obj:`XxxForQuestionAnswering` in which case it will default to
:obj:`["start_positions", "end_positions"]`.
"""
output_dir: str = field(
@@ -253,13 +259,16 @@ class TrainingArguments:
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
)
def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
remove_unused_columns: Optional[bool] = field(
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
)
label_names: Optional[List[str]] = field(
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
)
def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
@property
def train_batch_size(self) -> int: