DeepSpeed: hardcode torch.arange dtype on float usage to avoid incorrect initialization (#28760)
This commit is contained in:
@@ -25,6 +25,7 @@ import datasets
|
||||
from parameterized import parameterized
|
||||
|
||||
import tests.trainer.test_trainer
|
||||
import transformers
|
||||
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
||||
from transformers.integrations.deepspeed import (
|
||||
@@ -53,6 +54,8 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_available_on_dev
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from tests.trainer.test_trainer import ( # noqa
|
||||
RegressionModelConfig,
|
||||
RegressionPreTrainedModel,
|
||||
@@ -70,6 +73,7 @@ DEFAULT_MASTER_PORT = "10999"
|
||||
T5_SMALL = "t5-small"
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
GPT2_TINY = "sshleifer/tiny-gpt2"
|
||||
GPTJ_TINY = "hf-internal-testing/tiny-random-gptj"
|
||||
|
||||
|
||||
def load_json(path):
|
||||
@@ -297,6 +301,74 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
|
||||
)
|
||||
|
||||
def test_arange_bf16(self):
|
||||
# Tests that configuring DeepSpeed with 16 bits does not cause float `torch.arange()` tensors to be cast down.
|
||||
# NOTE -- this assumes that the function calls have the following downcast-preventing pattern, i.e.
|
||||
# `torch.arange(...,dtype=torch.int64)` followed by a cast like `.to(torch.float32)`. 🚨 If this pattern is
|
||||
# NOT applied (e.g. `torch.arange(...,dtype=torch.float32)` is used), DeepSpeed can automatically cast it down
|
||||
# at init time. See https://github.com/huggingface/transformers/issues/28685 for more info.
|
||||
|
||||
ds_config = {
|
||||
"train_batch_size": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
},
|
||||
"bf16": {"enabled": True},
|
||||
}
|
||||
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
self.assertTrue(dschf.is_zero3())
|
||||
self.assertTrue(is_deepspeed_zero3_enabled())
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
model = AutoModel.from_pretrained(GPTJ_TINY)
|
||||
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
|
||||
# The model weights are in BF16 as per deepspeed config
|
||||
self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16")
|
||||
good_deepspeed_sin_cos = model.h[0].attn.embed_positions
|
||||
|
||||
# Monkeypatches the function that creates RoPE embeddings using the INCORRECT torch.arange() pattern, and
|
||||
# then recreates the model
|
||||
def bad_deepspeed_create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
||||
# Incorrect pattern here: torch.arange has dtype=torch.float32 as its argument, and it will automatically
|
||||
# converted to BF16 by DeepSpeed
|
||||
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=inv_freq.dtype), inv_freq)
|
||||
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
|
||||
|
||||
good_deepspeed_create_sinusoidal_positions = transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
|
||||
transformers.models.gptj.modeling_gptj.create_sinusoidal_positions = bad_deepspeed_create_sinusoidal_positions
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
model = AutoModel.from_pretrained(GPTJ_TINY)
|
||||
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
|
||||
self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16")
|
||||
bad_deepspeed_sin_cos = model.h[0].attn.embed_positions
|
||||
|
||||
# Compares the two values: the two sets of values are different, and the correct one matches the torch
|
||||
# (i.e. outside DeepSpeed) version.
|
||||
good_torch_sin_cos = good_deepspeed_create_sinusoidal_positions(
|
||||
model.config.max_position_embeddings, model.config.rotary_dim
|
||||
)
|
||||
self.assertFalse(torch.allclose(good_deepspeed_sin_cos, bad_deepspeed_sin_cos))
|
||||
self.assertTrue(torch.allclose(good_torch_sin_cos, good_deepspeed_sin_cos.cpu()))
|
||||
|
||||
# Finally, we can see that the incorrect pattern is okay on vanilla torch, demostrating that this issue is
|
||||
# exclusive to DeepSpeed
|
||||
bad_torch_sin_cos = bad_deepspeed_create_sinusoidal_positions(
|
||||
model.config.max_position_embeddings, model.config.rotary_dim
|
||||
)
|
||||
self.assertTrue(torch.allclose(bad_torch_sin_cos, good_torch_sin_cos))
|
||||
|
||||
|
||||
class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user