Revert frozen training arguments (#25903)
* Revert frozen training arguments * TODO
This commit is contained in:
@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
|
||||
b: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# save resources not dealing with reporting (also avoids the warning when it's not set)
|
||||
self.report_to = []
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
class RepeatDataset:
|
||||
@@ -529,8 +529,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.check_trained_model(trainer.model)
|
||||
|
||||
# Re-training should restart from scratch, thus lead the same results and new seed should be used.
|
||||
args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
|
||||
trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
|
||||
trainer.args.seed = 314
|
||||
trainer.train()
|
||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
@@ -206,14 +205,7 @@ if __name__ == "__main__":
|
||||
logger.error(p.metrics)
|
||||
exit(1)
|
||||
|
||||
training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
|
||||
trainer = Trainer(
|
||||
model=DummyModel(),
|
||||
args=training_args,
|
||||
data_collator=DummyDataCollator(),
|
||||
eval_dataset=dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.args.eval_accumulation_steps = 2
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
logger.info(metrics)
|
||||
@@ -227,22 +219,15 @@ if __name__ == "__main__":
|
||||
logger.error(p.metrics)
|
||||
exit(1)
|
||||
|
||||
training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
|
||||
trainer = Trainer(
|
||||
model=DummyModel(),
|
||||
args=training_args,
|
||||
data_collator=DummyDataCollator(),
|
||||
eval_dataset=dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.args.eval_accumulation_steps = None
|
||||
|
||||
# Check that `dispatch_batches=False` will work on a finite iterable dataset
|
||||
|
||||
train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
|
||||
|
||||
model = RegressionModel()
|
||||
training_args = dataclasses.replace(
|
||||
training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
|
||||
)
|
||||
training_args.per_device_train_batch_size = 1
|
||||
training_args.max_steps = 1
|
||||
training_args.dispatch_batches = False
|
||||
trainer = Trainer(model, training_args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user