* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-06-23 17:42:46 +02:00
committed by GitHub
parent 9eac19eb59
commit f9be71b34d

View File

@@ -21,6 +21,7 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
import requests
from transformers import BartTokenizer, T5Tokenizer from transformers import BartTokenizer, T5Tokenizer
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
@@ -49,7 +50,7 @@ T5_SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available() and is_datasets_available() and is_faiss_available(): if is_torch_available() and is_datasets_available() and is_faiss_available():
import faiss import faiss
import torch import torch
from datasets import Dataset from datasets import Dataset, load_dataset
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
@@ -679,6 +680,24 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
@require_tokenizers @require_tokenizers
@require_torch_non_multi_accelerator @require_torch_non_multi_accelerator
class RagModelIntegrationTests(unittest.TestCase): class RagModelIntegrationTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
cls.dataset_path = cls.temp_dir.name
cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")
ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
ds.save_to_disk(cls.dataset_path)
url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
response = requests.get(url, stream=True)
with open(cls.index_path, "wb") as fp:
fp.write(response.content)
@classmethod
def tearDownClass(cls):
cls.temp_dir.cleanup()
def tearDown(self): def tearDown(self):
super().tearDown() super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch # clean-up as much as possible GPU memory occupied by PyTorch
@@ -722,8 +741,9 @@ class RagModelIntegrationTests(unittest.TestCase):
max_combined_length=300, max_combined_length=300,
dataset="wiki_dpr", dataset="wiki_dpr",
dataset_split="train", dataset_split="train",
index_name="exact", index_name="custom",
index_path=None, passages_path=self.dataset_path,
index_path=self.index_path,
use_dummy_dataset=True, use_dummy_dataset=True,
retrieval_vector_size=768, retrieval_vector_size=768,
retrieval_batch_size=8, retrieval_batch_size=8,
@@ -841,8 +861,8 @@ class RagModelIntegrationTests(unittest.TestCase):
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True) output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
# Expected outputs as given by model at integration time. # Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl" EXPECTED_OUTPUT_TEXT_1 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Uses"'
EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love" EXPECTED_OUTPUT_TEXT_2 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Ways"'
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1) self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2) self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@@ -903,7 +923,10 @@ class RagModelIntegrationTests(unittest.TestCase):
def test_rag_sequence_generate_batch(self): def test_rag_sequence_generate_batch(self):
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq") tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained( retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417" "facebook/rag-sequence-nq",
index_name="custom",
passages_path=self.dataset_path,
index_path=self.index_path,
) )
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to( rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
torch_device torch_device
@@ -926,12 +949,13 @@ class RagModelIntegrationTests(unittest.TestCase):
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
# PR #31938 cause the output being changed from `june 22, 2018` to `june 22 , 2018`.
EXPECTED_OUTPUTS = [ EXPECTED_OUTPUTS = [
" albert einstein", " albert einstein",
" june 22, 2018", " june 22 , 2018",
" amplitude modulation", " amplitude modulation",
" tim besley ( chairman )", " tim besley ( chairman )",
" june 20, 2018", " june 20 , 2018",
" 1980", " 1980",
" 7.0", " 7.0",
" 8", " 8",
@@ -943,9 +967,9 @@ class RagModelIntegrationTests(unittest.TestCase):
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq") tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained( retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq", "facebook/rag-sequence-nq",
index_name="exact", index_name="custom",
use_dummy_dataset=True, passages_path=self.dataset_path,
dataset_revision="b24a417", index_path=self.index_path,
) )
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to( rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
torch_device torch_device
@@ -981,10 +1005,10 @@ class RagModelIntegrationTests(unittest.TestCase):
EXPECTED_OUTPUTS = [ EXPECTED_OUTPUTS = [
" albert einstein", " albert einstein",
" june 22, 2018", " june 22 , 2018",
" amplitude modulation", " amplitude modulation",
" tim besley ( chairman )", " tim besley ( chairman )",
" june 20, 2018", " june 20 , 2018",
" 1980", " 1980",
" 7.0", " 7.0",
" 8", " 8",
@@ -995,7 +1019,7 @@ class RagModelIntegrationTests(unittest.TestCase):
def test_rag_token_generate_batch(self): def test_rag_token_generate_batch(self):
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained( retriever = RagRetriever.from_pretrained(
"facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417" "facebook/rag-token-nq", index_name="custom", passages_path=self.dataset_path, index_path=self.index_path
) )
rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to( rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
torch_device torch_device
@@ -1023,10 +1047,10 @@ class RagModelIntegrationTests(unittest.TestCase):
EXPECTED_OUTPUTS = [ EXPECTED_OUTPUTS = [
" albert einstein", " albert einstein",
" september 22, 2017", " september 22 , 2017",
" amplitude modulation", " amplitude modulation",
" stefan persson", " stefan persson",
" april 20, 2018", " april 20 , 2018",
" the 1970s", " the 1970s",
" 7.1. 2", " 7.1. 2",
" 13", " 13",
@@ -1037,6 +1061,24 @@ class RagModelIntegrationTests(unittest.TestCase):
@require_torch @require_torch
@require_retrieval @require_retrieval
class RagModelSaveLoadTests(unittest.TestCase): class RagModelSaveLoadTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
cls.dataset_path = cls.temp_dir.name
cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")
ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
ds.save_to_disk(cls.dataset_path)
url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
response = requests.get(url, stream=True)
with open(cls.index_path, "wb") as fp:
fp.write(response.content)
@classmethod
def tearDownClass(cls):
cls.temp_dir.cleanup()
def tearDown(self): def tearDown(self):
super().tearDown() super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch # clean-up as much as possible GPU memory occupied by PyTorch
@@ -1060,8 +1102,9 @@ class RagModelSaveLoadTests(unittest.TestCase):
max_combined_length=300, max_combined_length=300,
dataset="wiki_dpr", dataset="wiki_dpr",
dataset_split="train", dataset_split="train",
index_name="exact", index_name="custom",
index_path=None, passages_path=self.dataset_path,
index_path=self.index_path,
use_dummy_dataset=True, use_dummy_dataset=True,
retrieval_vector_size=768, retrieval_vector_size=768,
retrieval_batch_size=8, retrieval_batch_size=8,