From 224bde91caff4ccfd12277ab5e9bf97c61e22ee9 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Fri, 10 Jun 2022 18:50:29 +0200
Subject: [PATCH] Avoid GPU OOM for a TF Rag test (#17638)

Co-authored-by: ydshieh
---
 tests/models/rag/test_modeling_tf_rag.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/tests/models/rag/test_modeling_tf_rag.py b/tests/models/rag/test_modeling_tf_rag.py
index d9050acb63..314ce099ba 100644
--- a/tests/models/rag/test_modeling_tf_rag.py
+++ b/tests/models/rag/test_modeling_tf_rag.py
@@ -838,13 +838,6 @@ class TFRagModelIntegrationTests(unittest.TestCase):
         input_ids = input_dict.input_ids
         attention_mask = input_dict.attention_mask
 
-        output_ids = rag_token.generate(
-            input_ids,
-            attention_mask=attention_mask,
-        )
-
-        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-
         EXPECTED_OUTPUTS = [
             " albert einstein",
             " september 22, 2017",
@@ -855,7 +848,21 @@ class TFRagModelIntegrationTests(unittest.TestCase):
             " 7.1. 2",
             " 13",
         ]
-        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
+        # Split into 2 batches of 4 examples to avoid GPU OOM.
+        output_ids = rag_token.generate(
+            input_ids[:4],
+            attention_mask=attention_mask[:4],
+        )
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS[:4])
+
+        output_ids = rag_token.generate(
+            input_ids[4:],
+            attention_mask=attention_mask[4:],
+        )
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS[4:])
 
     @slow
     def test_rag_sequence_generate_batch(self):