[Fix] text-classification PL example (#6027)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
committed by
GitHub
parent
eb2bd8d6eb
commit
ffceef2042
@@ -3,6 +3,7 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from argparse import Namespace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -24,6 +25,8 @@ class GLUETransformer(BaseTransformer):
|
||||
mode = "sequence-classification"
|
||||
|
||||
def __init__(self, hparams):
|
||||
if type(hparams) == dict:
|
||||
hparams = Namespace(**hparams)
|
||||
hparams.glue_output_mode = glue_output_modes[hparams.task]
|
||||
num_labels = glue_tasks_num_labels[hparams.task]
|
||||
|
||||
@@ -41,7 +44,8 @@ class GLUETransformer(BaseTransformer):
|
||||
outputs = self(**inputs)
|
||||
loss = outputs[0]
|
||||
|
||||
tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
|
||||
# tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
|
||||
tensorboard_logs = {"loss": loss}
|
||||
return {"loss": loss, "log": tensorboard_logs}
|
||||
|
||||
def prepare_data(self):
|
||||
@@ -71,7 +75,7 @@ class GLUETransformer(BaseTransformer):
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
def load_dataset(self, mode, batch_size):
|
||||
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool) -> DataLoader:
|
||||
"Load datasets. Called after prepare data."
|
||||
|
||||
# We test on dev set to compare to benchmarks without having to submit to GLUE server
|
||||
|
||||
Reference in New Issue
Block a user