[Fix] text-classification PL example (#6027)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Bhashithe Abeysinghe
2020-08-06 15:46:43 -04:00
committed by GitHub
parent eb2bd8d6eb
commit ffceef2042
3 changed files with 14 additions and 6 deletions

View File

@@ -23,7 +23,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
-python3 run_pl_glue.py --data_dir $DATA_DIR \
+python3 run_pl_glue.py --gpus 1 --data_dir $DATA_DIR \
--task $TASK \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \

View File

@@ -3,6 +3,7 @@ import glob
import logging
import os
import time
+from argparse import Namespace
import numpy as np
import torch
@@ -24,6 +25,8 @@ class GLUETransformer(BaseTransformer):
mode = "sequence-classification"
def __init__(self, hparams):
+if type(hparams) == dict:
+hparams = Namespace(**hparams)
hparams.glue_output_mode = glue_output_modes[hparams.task]
num_labels = glue_tasks_num_labels[hparams.task]
@@ -41,7 +44,8 @@ class GLUETransformer(BaseTransformer):
outputs = self(**inputs)
loss = outputs[0]
-tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+# tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+tensorboard_logs = {"loss": loss}
return {"loss": loss, "log": tensorboard_logs}
def prepare_data(self):
@@ -71,7 +75,7 @@ class GLUETransformer(BaseTransformer):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
-def load_dataset(self, mode, batch_size):
+def get_dataloader(self, mode: int, batch_size: int, shuffle: bool) -> DataLoader:
"Load datasets. Called after prepare data."
# We test on dev set to compare to benchmarks without having to submit to GLUE server