From 184ef8ecd05ac783827b196e8d15403820efedf9 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Fri, 12 Mar 2021 06:16:40 -0500 Subject: [PATCH] TensorFlow tests: having from_pt set to True requires torch to be installed. (#10664) * TF model exists for Blenderbot 400M * Marian * RAG --- tests/test_modeling_tf_blenderbot.py | 2 +- tests/test_modeling_tf_marian.py | 2 +- tests/test_modeling_tf_rag.py | 14 +++++--------- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py index 050a223f0e..aa672a970c 100644 --- a/tests/test_modeling_tf_blenderbot.py +++ b/tests/test_modeling_tf_blenderbot.py @@ -309,7 +309,7 @@ class TFBlenderbot400MIntegrationTests(unittest.TestCase): @cached_property def model(self): - model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True) + model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name) return model @slow diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py index e4ccb28f00..55175f9d66 100644 --- a/tests/test_modeling_tf_marian.py +++ b/tests/test_modeling_tf_marian.py @@ -350,7 +350,7 @@ class AbstractMarianIntegrationTest(unittest.TestCase): @cached_property def model(self): warnings.simplefilter("error") - model: TFMarianMTModel = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True) + model: TFMarianMTModel = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name) assert isinstance(model, TFMarianMTModel) c = model.config self.assertListEqual(c.bad_words_ids, [[c.pad_token_id]]) diff --git a/tests/test_modeling_tf_rag.py b/tests/test_modeling_tf_rag.py index ec96aee8f8..8dd1cb39d1 100644 --- a/tests/test_modeling_tf_rag.py +++ b/tests/test_modeling_tf_rag.py @@ -562,7 +562,7 @@ class TFRagModelIntegrationTests(unittest.TestCase): ) def token_model_nq_checkpoint(self, retriever): - return TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", from_pt=True, retriever=retriever) + return TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) def get_rag_config(self): question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base") @@ -799,7 +799,7 @@ class TFRagModelIntegrationTests(unittest.TestCase): def test_rag_token_greedy_search(self): tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) - rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True) + rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) # check first two questions input_dict = tokenizer( @@ -833,7 +833,7 @@ class TFRagModelIntegrationTests(unittest.TestCase): # NOTE: gold labels comes from num_beam=4, so this is effectively beam-search test tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) - rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True) + rag_token = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) input_dict = tokenizer( self.test_data_questions, @@ -877,9 +877,7 @@ class TFRagModelIntegrationTests(unittest.TestCase): retriever = RagRetriever.from_pretrained( "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True ) - rag_sequence = TFRagSequenceForGeneration.from_pretrained( - "facebook/rag-sequence-nq", retriever=retriever, from_pt=True - ) + rag_sequence = TFRagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever) input_dict = tokenizer( self.test_data_questions, @@ -923,9 +921,7 @@ class TFRagModelIntegrationTests(unittest.TestCase): retriever = RagRetriever.from_pretrained( "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True ) - rag_sequence = TFRagSequenceForGeneration.from_pretrained( - "facebook/rag-sequence-nq", retriever=retriever, from_pt=True - ) + rag_sequence = TFRagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever) input_dict = tokenizer( self.test_data_questions, return_tensors="tf",