Fixes to run_seq2seq and instructions (#9734)
* Fixes to run_seq2seq and instructions * Add more defaults for summarization
This commit is contained in:
@@ -136,10 +136,10 @@ class DataTrainingArguments:
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=142,
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
@@ -175,6 +175,9 @@ class DataTrainingArguments:
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
@@ -190,11 +193,22 @@ class DataTrainingArguments:
|
||||
raise ValueError(
|
||||
"`task` should be summarization, summarization_{dataset}, translation or translation_{xx}_to_{yy}."
|
||||
)
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
summarization_name_mapping = {
|
||||
"amazon_reviews_multi": ("review_body", "review_title"),
|
||||
"big_patent": ("description", "abstract"),
|
||||
"cnn_dailymail": ("article", "highlights"),
|
||||
"orange_sum": ("text", "summary"),
|
||||
"pn_summary": ("article", "summary"),
|
||||
"psc": ("extract_text", "summary_text"),
|
||||
"samsum": ("dialogue", "summary"),
|
||||
"thaisum": ("body", "summary"),
|
||||
"xglue": ("news_body", "news_title"),
|
||||
"xsum": ("document", "summary"),
|
||||
"wiki_summary": ("article", "highlights"),
|
||||
}
|
||||
|
||||
|
||||
@@ -302,6 +316,16 @@ def main():
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
# Get the default prefix if None is passed.
|
||||
if data_args.source_prefix is None:
|
||||
task_specific_params = model.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
prefix = task_specific_params.get("prefix", "")
|
||||
else:
|
||||
prefix = ""
|
||||
else:
|
||||
prefix = data_args.source_prefix
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
if training_args.do_train:
|
||||
@@ -362,6 +386,7 @@ def main():
|
||||
else:
|
||||
inputs = examples[text_column]
|
||||
targets = examples[summary_column]
|
||||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
|
||||
Reference in New Issue
Block a user