Stop passing None to compile() in TF examples (#29597)
* Fix examples to stop passing None to compile(), rework example invocation for run_text_classification.py * Add Amy's fix
This commit is contained in:
@@ -509,7 +509,7 @@ def main():
|
|||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
).with_options(dataset_options)
|
).with_options(dataset_options)
|
||||||
else:
|
else:
|
||||||
optimizer = None
|
optimizer = "sgd" # Just write anything because we won't be using it
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
eval_dataset = model.prepare_tf_dataset(
|
eval_dataset = model.prepare_tf_dataset(
|
||||||
|
|||||||
@@ -482,7 +482,7 @@ def main():
|
|||||||
adam_global_clipnorm=training_args.max_grad_norm,
|
adam_global_clipnorm=training_args.max_grad_norm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
optimizer = None
|
optimizer = "sgd" # Just write anything because we won't be using it
|
||||||
# Transformers models compute the right loss for their task by default when labels are passed, and will
|
# Transformers models compute the right loss for their task by default when labels are passed, and will
|
||||||
# use this for training unless you specify your own loss function in compile().
|
# use this for training unless you specify your own loss function in compile().
|
||||||
model.compile(optimizer=optimizer, metrics=["accuracy"], jit_compile=training_args.xla)
|
model.compile(optimizer=optimizer, metrics=["accuracy"], jit_compile=training_args.xla)
|
||||||
|
|||||||
@@ -706,7 +706,8 @@ def main():
|
|||||||
model.compile(optimizer=optimizer, jit_compile=training_args.xla, metrics=["accuracy"])
|
model.compile(optimizer=optimizer, jit_compile=training_args.xla, metrics=["accuracy"])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model.compile(optimizer=None, jit_compile=training_args.xla, metrics=["accuracy"])
|
# Optimizer doesn't matter as it won't be used anyway
|
||||||
|
model.compile(optimizer="sgd", jit_compile=training_args.xla, metrics=["accuracy"])
|
||||||
training_dataset = None
|
training_dataset = None
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
|
|||||||
@@ -621,7 +621,7 @@ def main():
|
|||||||
adam_global_clipnorm=training_args.max_grad_norm,
|
adam_global_clipnorm=training_args.max_grad_norm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
optimizer = None
|
optimizer = "sgd" # Just write anything because we won't be using it
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,10 @@ python run_text_classification.py \
|
|||||||
--train_file training_data.json \
|
--train_file training_data.json \
|
||||||
--validation_file validation_data.json \
|
--validation_file validation_data.json \
|
||||||
--output_dir output/ \
|
--output_dir output/ \
|
||||||
--test_file data_to_predict.json
|
--test_file data_to_predict.json \
|
||||||
|
--do_train \
|
||||||
|
--do_eval \
|
||||||
|
--do_predict
|
||||||
```
|
```
|
||||||
|
|
||||||
## run_glue.py
|
## run_glue.py
|
||||||
|
|||||||
@@ -477,7 +477,7 @@ def main():
|
|||||||
adam_global_clipnorm=training_args.max_grad_norm,
|
adam_global_clipnorm=training_args.max_grad_norm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
optimizer = "adam" # Just write anything because we won't be using it
|
optimizer = "sgd" # Just write anything because we won't be using it
|
||||||
if is_regression:
|
if is_regression:
|
||||||
metrics = []
|
metrics = []
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -526,7 +526,7 @@ def main():
|
|||||||
adam_global_clipnorm=training_args.max_grad_norm,
|
adam_global_clipnorm=training_args.max_grad_norm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
optimizer = None
|
optimizer = "sgd" # Just use any default
|
||||||
if is_regression:
|
if is_regression:
|
||||||
metrics = []
|
metrics = []
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -584,7 +584,7 @@ def main():
|
|||||||
adam_global_clipnorm=training_args.max_grad_norm,
|
adam_global_clipnorm=training_args.max_grad_norm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
optimizer = None
|
optimizer = "sgd" # Just write anything because we won't be using it
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Metric and postprocessing
|
# region Metric and postprocessing
|
||||||
|
|||||||
Reference in New Issue
Block a user