New run_seq2seq script (#9605)
* New run_seq2seq script * Add tests * Mark as slow * Update examples/seq2seq/run_seq2seq.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/data/data_collator.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update src/transformers/data/data_collator.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Address review comments Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -324,6 +324,7 @@ if is_torch_available():
|
||||
"DataCollator",
|
||||
"DataCollatorForLanguageModeling",
|
||||
"DataCollatorForPermutationLanguageModeling",
|
||||
"DataCollatorForSeq2Seq",
|
||||
"DataCollatorForSOP",
|
||||
"DataCollatorForTokenClassification",
|
||||
"DataCollatorForWholeWordMask",
|
||||
@@ -1395,6 +1396,7 @@ if TYPE_CHECKING:
|
||||
DataCollator,
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorForSOP,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorForWholeWordMask,
|
||||
|
||||
@@ -224,6 +224,63 @@ def tolist(x: Union[List[Any], torch.Tensor]):
|
||||
return x.tolist() if isinstance(x, torch.Tensor) else x
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs received, as well as the labels.
|
||||
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
|
||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
label_pad_token_id: int = -100
|
||||
|
||||
def __call__(self, features):
|
||||
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
|
||||
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
||||
# same length to return tensors.
|
||||
if labels is not None:
|
||||
max_label_length = max(len(l) for l in labels)
|
||||
padding_side = self.tokenizer.padding_side
|
||||
for feature in features:
|
||||
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
|
||||
feature["labels"] = (
|
||||
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
|
||||
)
|
||||
|
||||
return self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForLanguageModeling:
|
||||
"""
|
||||
|
||||
@@ -35,6 +35,11 @@ class DataCollatorForPermutationLanguageModeling:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class DataCollatorForSeq2Seq:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class DataCollatorForSOP:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
Reference in New Issue
Block a user