@@ -328,14 +328,14 @@ def _compute_llama3_parameters(
|
|||||||
wavelen = 2 * math.pi / inv_freq
|
wavelen = 2 * math.pi / inv_freq
|
||||||
# wavelen < high_freq_wavelen: do nothing
|
# wavelen < high_freq_wavelen: do nothing
|
||||||
# wavelen > low_freq_wavelen: divide by factor
|
# wavelen > low_freq_wavelen: divide by factor
|
||||||
inv_freq_new = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
||||||
# otherwise: interpolate between the two, using a smooth factor
|
# otherwise: interpolate between the two, using a smooth factor
|
||||||
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||||
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_new / factor + smooth_factor * inv_freq_new
|
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
||||||
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
||||||
inv_freq_new = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_new)
|
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
||||||
|
|
||||||
return inv_freq, attention_factor
|
return inv_freq_llama, attention_factor
|
||||||
|
|
||||||
|
|
||||||
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import pytest
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
|
from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -718,6 +718,34 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
# 8 is for A100 / A10 and 7 for T4
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_read_token
|
||||||
|
def test_llama_3_1_hard(self):
|
||||||
|
"""
|
||||||
|
An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences
|
||||||
|
from llama 3.1's RoPE can be detected
|
||||||
|
"""
|
||||||
|
EXPECTED_TEXT = (
|
||||||
|
"Tell me about the french revolution. The french revolution was a period of radical social and political "
|
||||||
|
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
|
||||||
|
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
|
||||||
|
"First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative "
|
||||||
|
"assembly that had not met since 1614. The Third Estate, which represented the common people, "
|
||||||
|
"demanded greater representation and eventually broke away to form the National Assembly. This marked "
|
||||||
|
"the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
"meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
input_text = ["Tell me about the french revolution."]
|
||||||
|
model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
|
||||||
|
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
|
self.assertEqual(generated_text, EXPECTED_TEXT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_model_7b_logits_bf16(self):
|
def test_model_7b_logits_bf16(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user