@@ -28,7 +28,6 @@ from accelerate.utils import write_basic_config
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
backend_device_count,
|
backend_device_count,
|
||||||
is_torch_fp16_available_on_device,
|
|
||||||
run_command,
|
run_command,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
@@ -93,9 +92,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
--with_tracking
|
--with_tracking
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_torch_fp16_available_on_device(torch_device):
|
|
||||||
testargs.append("--fp16")
|
|
||||||
|
|
||||||
run_command(self._launch_args + testargs)
|
run_command(self._launch_args + testargs)
|
||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||||
@@ -325,9 +321,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
--checkpointing_steps 1
|
--checkpointing_steps 1
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_torch_fp16_available_on_device(torch_device):
|
|
||||||
testargs.append("--fp16")
|
|
||||||
|
|
||||||
run_command(self._launch_args + testargs)
|
run_command(self._launch_args + testargs)
|
||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
# The base model scores a 25%
|
# The base model scores a 25%
|
||||||
|
|||||||
Reference in New Issue
Block a user