From e6b811f0a7a3174d0e62c2cdf876230510031319 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 27 Aug 2020 09:22:18 -0700 Subject: [PATCH] [testing] replace hardcoded paths to allow running tests from anywhere (#6523) * [testing] replace hardcoded paths to allow running tests from anywhere * fix the merge conflict --- src/transformers/testing_utils.py | 10 ++++++++++ tests/test_tokenization_fast.py | 4 ++-- tests/test_trainer.py | 6 +++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 17d6b0e5bf..92117ca2a1 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1,3 +1,4 @@ +import inspect import os import re import shutil @@ -144,6 +145,15 @@ def require_torch_and_cuda(test_case): return test_case +def get_tests_dir(): + """ + returns the full path to the `tests` dir, so that the tests can be invoked from anywhere + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + return os.path.abspath(os.path.dirname(caller__file__)) + + # # Helper functions for dealing with testing text outputs # The original code came from: diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index 4dcf0bf896..a0a9d49646 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -15,7 +15,7 @@ from transformers import ( TransfoXLTokenizer, is_torch_available, ) -from transformers.testing_utils import require_torch +from transformers.testing_utils import get_tests_dir, require_torch from transformers.tokenization_distilbert import DistilBertTokenizerFast from transformers.tokenization_openai import OpenAIGPTTokenizerFast from transformers.tokenization_roberta import RobertaTokenizerFast @@ -42,7 +42,7 @@ class CommonFastTokenizerTest(unittest.TestCase): TOKENIZERS_CLASSES = frozenset([]) def setUp(self) -> None: - with open("tests/fixtures/sample_text.txt", encoding="utf-8") as f_data: + with open(f"{get_tests_dir()}/fixtures/sample_text.txt", encoding="utf-8") as f_data: self._data = f_data.read().replace("\n\n", "\n").strip() def test_all_tokenizers(self): diff --git a/tests/test_trainer.py b/tests/test_trainer.py index e8dc86ea6f..034cc552f9 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -4,7 +4,7 @@ import nlp import numpy as np from transformers import AutoTokenizer, TrainingArguments, is_torch_available -from transformers.testing_utils import require_torch +from transformers.testing_utils import get_tests_dir, require_torch if is_torch_available(): @@ -20,7 +20,7 @@ if is_torch_available(): ) -PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt" +PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" class RegressionDataset: @@ -262,7 +262,7 @@ class TrainerIntegrationTest(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID) data_args = GlueDataTrainingArguments( - task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True + task_name="mrpc", data_dir=f"{get_tests_dir()}/fixtures/tests_samples/MRPC", overwrite_cache=True ) eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")