[lightning_base] fix s2s logging, only make train_loader once (#6404)
This commit is contained in:
@@ -75,7 +75,7 @@ class GLUETransformer(BaseTransformer):
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
"Load datasets. Called after prepare data."
|
||||
|
||||
# We test on dev set to compare to benchmarks without having to submit to GLUE server
|
||||
@@ -161,13 +161,6 @@ class GLUETransformer(BaseTransformer):
|
||||
type=int,
|
||||
help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
|
||||
Reference in New Issue
Block a user