Upgrade examples to pl=0.8.1(#5146)
This commit is contained in:
@@ -24,6 +24,7 @@ logger = logging.getLogger()
|
||||
FP16_EVER = False
|
||||
CHEAP_ARGS = {
|
||||
"logger": "default",
|
||||
"num_workers": 2,
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
@@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list):
|
||||
f.write("\n".join(articles))
|
||||
|
||||
|
||||
BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
|
||||
MSG = "T5 is broken at the moment"
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
|
||||
|
||||
def make_test_data_dir():
|
||||
@@ -92,7 +94,6 @@ def make_test_data_dir():
|
||||
return tmp_dir
|
||||
|
||||
|
||||
@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
|
||||
class TestSummarizationDistiller(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
freeze_encoder=True,
|
||||
gpus=2,
|
||||
sortish_sampler=False,
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
||||
def test_bdc_fp16(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=3.0,
|
||||
freeze_encoder=True,
|
||||
gpus=1,
|
||||
fp16=FP16_EVER,
|
||||
fp16_opt_level="O1",
|
||||
fp16=FP16_EVER,
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
||||
def test_bdc_t5_eval_fp16(self):
|
||||
def test_bdc_t5_train(self):
|
||||
updates = dict(
|
||||
fp16=FP16_EVER,
|
||||
gpus=1,
|
||||
gpus=1 if torch.cuda.is_available() else 0,
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
do_train=False,
|
||||
do_predict=True,
|
||||
tokenizer_name=None,
|
||||
no_teacher=True,
|
||||
)
|
||||
self._bart_distiller_cli(updates, check_contents=False)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
||||
def test_bdc_t5_train_fp16(self):
|
||||
updates = dict(
|
||||
fp16=FP16_EVER,
|
||||
gpus=1,
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
model_name_or_path=T5_TINY,
|
||||
do_train=True,
|
||||
do_predict=True,
|
||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
||||
tokenizer_name=T5_TINY,
|
||||
no_teacher=True,
|
||||
alpha_hid=2.0,
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
@@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_checkpointing(self):
|
||||
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
@@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
def test_bdc_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=2.0,
|
||||
teacher="patrickvonplaten/t5-tiny-random",
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_t5_eval(self):
|
||||
updates = dict(
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
do_train=False,
|
||||
do_predict=True,
|
||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
||||
no_teacher=True,
|
||||
)
|
||||
self._bart_distiller_cli(updates, check_contents=False)
|
||||
|
||||
def _bart_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
model_type="bart",
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
num_train_epochs=2,
|
||||
@@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
self.assertIn(ckpt_name, contents)
|
||||
self.assertIn("metrics.pkl", contents)
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("val_generations_1.txt", contents)
|
||||
self.assertIn("val_1_results.txt", contents)
|
||||
self.assertIn("val_generations_00001.txt", contents)
|
||||
self.assertIn("val_results_00001.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
# self.assertEqual(len(contents), 15)
|
||||
|
||||
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
|
||||
import pandas as pd
|
||||
|
||||
val_df = pd.DataFrame(metrics["val"])
|
||||
train_df = pd.DataFrame(metrics["train"])
|
||||
test_df = pd.DataFrame(metrics["test"])
|
||||
desired_n_evals = args_d["num_train_epochs"] * 2 + 1
|
||||
self.assertEqual(val_df.shape[0], desired_n_evals) #
|
||||
self.assertEqual(test_df.shape[1], val_df.shape[1])
|
||||
self.assertEqual(train_df.shape[0], 0)
|
||||
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
|
||||
return model
|
||||
|
||||
|
||||
@@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase):
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
tokenizer_name=None, # "patrickvonplaten/t5-tiny-random",
|
||||
model_name_or_path=T5_TINY,
|
||||
tokenizer_name=None, # T5_TINY,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
gpus=0,
|
||||
|
||||
Reference in New Issue
Block a user