[s2strainer] fix eval dataset loading (#7477)
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import (
|
|||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
from transformers.trainer_utils import EvaluationStrategy
|
||||||
from utils import (
|
from utils import (
|
||||||
LegacySeq2SeqDataset,
|
LegacySeq2SeqDataset,
|
||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
@@ -350,7 +351,7 @@ def main():
|
|||||||
max_source_length=data_args.max_source_length,
|
max_source_length=data_args.max_source_length,
|
||||||
prefix=model.config.prefix or "",
|
prefix=model.config.prefix or "",
|
||||||
)
|
)
|
||||||
if training_args.do_eval
|
if training_args.do_eval or training_args.evaluation_strategy != EvaluationStrategy.NO
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
test_dataset = (
|
test_dataset = (
|
||||||
|
|||||||
Reference in New Issue
Block a user