DeepSpeed: hardcode torch.arange dtype on float usage to avoid incorrect initialization (#28760)

This commit is contained in:
Joao Gante
2024-01-31 14:39:07 +00:00
committed by GitHub
parent f7076cd346
commit beb2a09687
50 changed files with 192 additions and 118 deletions

View File

@@ -25,6 +25,7 @@ import datasets
from parameterized import parameterized
import tests.trainer.test_trainer
import transformers
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
from transformers.integrations.deepspeed import (
@@ -53,6 +54,8 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_available_on_dev
if is_torch_available():
import torch
from tests.trainer.test_trainer import ( # noqa
RegressionModelConfig,
RegressionPreTrainedModel,
@@ -70,6 +73,7 @@ DEFAULT_MASTER_PORT = "10999"
T5_SMALL = "t5-small"
T5_TINY = "patrickvonplaten/t5-tiny-random"
GPT2_TINY = "sshleifer/tiny-gpt2"
GPTJ_TINY = "hf-internal-testing/tiny-random-gptj"
def load_json(path):
@@ -297,6 +301,74 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
)
def test_arange_bf16(self):
# Tests that configuring DeepSpeed with 16 bits does not cause float `torch.arange()` tensors to be cast down.
# NOTE -- this assumes that the function calls have the following downcast-preventing pattern, i.e.
# `torch.arange(...,dtype=torch.int64)` followed by a cast like `.to(torch.float32)`. 🚨 If this pattern is
# NOT applied (e.g. `torch.arange(...,dtype=torch.float32)` is used), DeepSpeed can automatically cast it down
# at init time. See https://github.com/huggingface/transformers/issues/28685 for more info.
ds_config = {
"train_batch_size": 1,
"zero_optimization": {
"stage": 3,
},
"bf16": {"enabled": True},
}
dschf = HfDeepSpeedConfig(ds_config)
self.assertTrue(dschf.is_zero3())
self.assertTrue(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
model = AutoModel.from_pretrained(GPTJ_TINY)
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
# The model weights are in BF16 as per deepspeed config
self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16")
good_deepspeed_sin_cos = model.h[0].attn.embed_positions
# Monkeypatches the function that creates RoPE embeddings using the INCORRECT torch.arange() pattern, and
# then recreates the model
def bad_deepspeed_create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
# Incorrect pattern here: torch.arange has dtype=torch.float32 as its argument, and it will automatically
# converted to BF16 by DeepSpeed
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=inv_freq.dtype), inv_freq)
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
good_deepspeed_create_sinusoidal_positions = transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
transformers.models.gptj.modeling_gptj.create_sinusoidal_positions = bad_deepspeed_create_sinusoidal_positions
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
model = AutoModel.from_pretrained(GPTJ_TINY)
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16")
bad_deepspeed_sin_cos = model.h[0].attn.embed_positions
# Compares the two values: the two sets of values are different, and the correct one matches the torch
# (i.e. outside DeepSpeed) version.
good_torch_sin_cos = good_deepspeed_create_sinusoidal_positions(
model.config.max_position_embeddings, model.config.rotary_dim
)
self.assertFalse(torch.allclose(good_deepspeed_sin_cos, bad_deepspeed_sin_cos))
self.assertTrue(torch.allclose(good_torch_sin_cos, good_deepspeed_sin_cos.cpu()))
# Finally, we can see that the incorrect pattern is okay on vanilla torch, demostrating that this issue is
# exclusive to DeepSpeed
bad_torch_sin_cos = bad_deepspeed_create_sinusoidal_positions(
model.config.max_position_embeddings, model.config.rotary_dim
)
self.assertTrue(torch.allclose(bad_torch_sin_cos, good_torch_sin_cos))
class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
def setUp(self):