Remove subclass for sortish sampler (#9907)
* Remove subclass for sortish sampler * Use old Seq2SeqTrainer in script * Styling
This commit is contained in:
@@ -20,6 +20,8 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
from seq2seq_trainer import Seq2SeqTrainer
|
||||||
|
from seq2seq_training_args import Seq2SeqTrainingArguments
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
@@ -27,8 +29,6 @@ from transformers import (
|
|||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
MBartTokenizer,
|
MBartTokenizer,
|
||||||
MBartTokenizerFast,
|
MBartTokenizerFast,
|
||||||
Seq2SeqTrainer,
|
|
||||||
Seq2SeqTrainingArguments,
|
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
||||||
@@ -286,6 +286,7 @@ def main():
|
|||||||
trainer = Seq2SeqTrainer(
|
trainer = Seq2SeqTrainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
data_args=data_args,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
data_collator=Seq2SeqDataCollator(
|
data_collator=Seq2SeqDataCollator(
|
||||||
@@ -323,9 +324,7 @@ def main():
|
|||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
metrics = trainer.evaluate(
|
metrics = trainer.evaluate(metric_key_prefix="val")
|
||||||
metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
|
|
||||||
)
|
|
||||||
metrics["val_n_objs"] = data_args.n_val
|
metrics["val_n_objs"] = data_args.n_val
|
||||||
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
||||||
|
|
||||||
@@ -337,12 +336,7 @@ def main():
|
|||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
test_output = trainer.predict(
|
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
|
||||||
test_dataset=test_dataset,
|
|
||||||
metric_key_prefix="test",
|
|
||||||
max_length=data_args.val_max_target_length,
|
|
||||||
num_beams=data_args.eval_beams,
|
|
||||||
)
|
|
||||||
metrics = test_output.metrics
|
metrics = test_output.metrics
|
||||||
metrics["test_n_objs"] = data_args.n_test
|
metrics["test_n_objs"] = data_args.n_test
|
||||||
|
|
||||||
|
|||||||
@@ -17,14 +17,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DistributedSampler, RandomSampler
|
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
from .file_utils import is_torch_tpu_available
|
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
from .trainer_pt_utils import get_tpu_sampler
|
|
||||||
from .trainer_utils import PredictionOutput
|
from .trainer_utils import PredictionOutput
|
||||||
from .training_args import ParallelMode
|
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -36,24 +32,6 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Seq2SeqTrainer(Trainer):
|
class Seq2SeqTrainer(Trainer):
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
|
||||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
|
||||||
return None
|
|
||||||
elif is_torch_tpu_available():
|
|
||||||
return get_tpu_sampler(self.train_dataset)
|
|
||||||
else:
|
|
||||||
if self.args.sortish_sampler:
|
|
||||||
self.train_dataset.make_sortish_sampler(
|
|
||||||
self.args.per_device_train_batch_size,
|
|
||||||
distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
RandomSampler(self.train_dataset)
|
|
||||||
if self.args.local_rank == -1
|
|
||||||
else DistributedSampler(self.train_dataset)
|
|
||||||
)
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
eval_dataset: Optional[Dataset] = None,
|
eval_dataset: Optional[Dataset] = None,
|
||||||
|
|||||||
Reference in New Issue
Block a user