Trainer multi label (#7191)
* Trainer accep multiple labels * Missing import * Fix dosctrings
This commit is contained in:
@@ -2,7 +2,7 @@ import dataclasses
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||
from .utils import logging
|
||||
@@ -128,6 +128,12 @@ class TrainingArguments:
|
||||
forward method.
|
||||
|
||||
(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
|
||||
label_names (:obj:`List[str]`, `optional`):
|
||||
The list of keys in your dictionary of inputs that correspond to the labels.
|
||||
|
||||
Will eventually default to :obj:`["labels"]` except if the model used is one of the
|
||||
:obj:`XxxForQuestionAnswering` in which case it will default to
|
||||
:obj:`["start_positions", "end_positions"]`.
|
||||
"""
|
||||
|
||||
output_dir: str = field(
|
||||
@@ -253,13 +259,16 @@ class TrainingArguments:
|
||||
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.disable_tqdm is None:
|
||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||
|
||||
remove_unused_columns: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
|
||||
)
|
||||
label_names: Optional[List[str]] = field(
|
||||
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.disable_tqdm is None:
|
||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||
|
||||
@property
|
||||
def train_batch_size(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user