fix the shuffle agrument usage and the default (#6307)

This commit is contained in:
Stas Bekman
2020-08-06 17:32:28 -07:00
committed by GitHub
parent ffceef2042
commit 175cd45e13
2 changed files with 3 additions and 2 deletions

View File

@@ -75,7 +75,7 @@ class GLUETransformer(BaseTransformer):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool) -> DataLoader:
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
"Load datasets. Called after prepare data."
# We test on dev set to compare to benchmarks without having to submit to GLUE server
@@ -95,7 +95,7 @@ class GLUETransformer(BaseTransformer):
return DataLoader(
TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels),
batch_size=batch_size,
shuffle=True,
shuffle=shuffle,
)
def validation_step(self, batch, batch_idx):