Add torch.compile for Mistral (#30642)
* first version * fix sliding window * fix style * add sliding window cache * fix style * address comments * fix test * fix style * move sliding window check inside cache init * revert changes on irrelevant files & add comment on SlidingWindowCache * address comments & fix style fix style * update causal mask * [run-slow] mistral * [run-slow] mistral * [run-slow] mistral * [run-slow] mistral * [run-slow] mistral * [run-slow] llama * [run-slow] mistral * [run-slow] mistral * [run-slow] mistral * revert CI from a10 to t4 * wrap up
This commit is contained in:
@@ -12,14 +12,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Mistral model. """
|
||||
|
||||
"""Testing suite for the PyTorch Mistral model."""
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
|
||||
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
@@ -648,6 +648,74 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
def test_compile_static_cache(self):
|
||||
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
||||
self.skipTest("This test requires torch >= 2.3 to run.")
|
||||
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
8: [
|
||||
"My favourite condiment is 100% ketchup. I love it on everything. "
|
||||
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
|
||||
],
|
||||
7: [
|
||||
"My favourite condiment is 100% ketchup. I love it on everything. "
|
||||
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
|
||||
],
|
||||
}
|
||||
|
||||
prompts = ["My favourite condiment is "]
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
# Dynamic Cache
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)
|
||||
|
||||
# Static Cache
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
|
||||
|
||||
# Sliding Window Cache
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
|
||||
|
||||
# Static Cache + compile
|
||||
forward_function = model.forward
|
||||
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||
|
||||
# Sliding Window Cache + compile
|
||||
torch._dynamo.reset()
|
||||
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user