Fix PL token classification examples (#6682)

This commit is contained in:
vblagoje
2020-08-24 11:30:06 -04:00
committed by GitHub
parent a573777901
commit dd522da004
3 changed files with 14 additions and 8 deletions

View File

@@ -86,7 +86,7 @@ class NERTransformer(BaseTransformer):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
def get_dataloader(self, mode: int, batch_size: int) -> DataLoader:
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
"Load datasets. Called after prepare data."
cached_features_file = self._feature_file(mode)
logger.info("Loading features from cached file %s", cached_features_file)