Revert frozen training arguments (#25903)

* Revert frozen training arguments

* TODO
This commit is contained in:
Zach Mueller
2023-09-01 11:24:12 -04:00
committed by GitHub
parent 69c5b8f186
commit be0e189bd3
9 changed files with 31 additions and 58 deletions

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from typing import Dict
import numpy as np
@@ -206,14 +205,7 @@ if __name__ == "__main__":
logger.error(p.metrics)
exit(1)
training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
trainer.args.eval_accumulation_steps = 2
metrics = trainer.evaluate()
logger.info(metrics)
@@ -227,22 +219,15 @@ if __name__ == "__main__":
logger.error(p.metrics)
exit(1)
training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
trainer.args.eval_accumulation_steps = None
# Check that `dispatch_batches=False` will work on a finite iterable dataset
train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
model = RegressionModel()
training_args = dataclasses.replace(
training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
)
training_args.per_device_train_batch_size = 1
training_args.max_steps = 1
training_args.dispatch_batches = False
trainer = Trainer(model, training_args, train_dataset=train_dataset)
trainer.train()