Fix RealmModelIntegrationTest.test_inference_open_qa (#21136)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-01-16 15:09:52 +01:00
committed by GitHub
parent a5327c6a9a
commit a45914193a

View File

@@ -480,15 +480,12 @@ class RealmModelIntegrationTest(unittest.TestCase):
def test_inference_open_qa(self): def test_inference_open_qa(self):
from transformers.models.realm.retrieval_realm import RealmRetriever from transformers.models.realm.retrieval_realm import RealmRetriever
config = RealmConfig()
tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa") tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa") retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained( model = RealmForOpenQA.from_pretrained(
"google/realm-orqa-nq-openqa", "google/realm-orqa-nq-openqa",
retriever=retriever, retriever=retriever,
config=config,
) )
question = "Who is the pioneer in modern computer science?" question = "Who is the pioneer in modern computer science?"