From dd522da00424e5508d47fdcb7872ed2a59a9ed89 Mon Sep 17 00:00:00 2001 From: vblagoje Date: Mon, 24 Aug 2020 11:30:06 -0400 Subject: [PATCH] Fix PL token classification examples (#6682) --- examples/token-classification/run.sh | 9 ++++++--- examples/token-classification/run_pl.sh | 11 +++++++---- examples/token-classification/run_pl_ner.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/token-classification/run.sh b/examples/token-classification/run.sh index 9ff10ca36d..f5cbf0d50e 100755 --- a/examples/token-classification/run.sh +++ b/examples/token-classification/run.sh @@ -1,8 +1,11 @@ -curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \ +## The relevant files are currently on a shared Google +## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J +## Monitor for changes and eventually migrate to nlp dataset +curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \ | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp -curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \ +curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \ | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp -curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \ +curl -L 'https://drive.google.com/uc?export=download&id=1u9mb7kNJHWQCWyweMDRMuTFoOHOfeBTH' \ | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp export MAX_LENGTH=128 diff --git a/examples/token-classification/run_pl.sh b/examples/token-classification/run_pl.sh index ecbd5d3b4f..5abcd981bf 100755 --- a/examples/token-classification/run_pl.sh +++ b/examples/token-classification/run_pl.sh @@ -3,11 +3,14 @@ # for seqeval metrics import pip install -r ../requirements.txt -curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \ +## The relevant files are currently on a shared Google +## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J +## Monitor for changes and eventually migrate to nlp dataset +curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \ | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp -curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \ +curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \ | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp -curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \ +curl -L 'https://drive.google.com/uc?export=download&id=1u9mb7kNJHWQCWyweMDRMuTFoOHOfeBTH' \ | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp export MAX_LENGTH=128 @@ -29,7 +32,6 @@ mkdir -p $OUTPUT_DIR export PYTHONPATH="../":"${PYTHONPATH}" python3 run_pl_ner.py --data_dir ./ \ ---model_type bert \ --labels ./labels.txt \ --model_name_or_path $BERT_MODEL \ --output_dir $OUTPUT_DIR \ @@ -37,5 +39,6 @@ python3 run_pl_ner.py --data_dir ./ \ --num_train_epochs $NUM_EPOCHS \ --train_batch_size $BATCH_SIZE \ --seed $SEED \ +--gpus 1 \ --do_train \ --do_predict diff --git a/examples/token-classification/run_pl_ner.py b/examples/token-classification/run_pl_ner.py index bcdf2ba5df..c82cff74d8 100644 --- a/examples/token-classification/run_pl_ner.py +++ b/examples/token-classification/run_pl_ner.py @@ -86,7 +86,7 @@ class NERTransformer(BaseTransformer): logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) - def get_dataloader(self, mode: int, batch_size: int) -> DataLoader: + def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader: "Load datasets. Called after prepare data." cached_features_file = self._feature_file(mode) logger.info("Loading features from cached file %s", cached_features_file)