Fix FbgemmFp8Linear not preserving tensor shape (#33239)
* add tests for linear shape behavior * fix linear shape behavior ended up adding the reshape at the end, after f8f8bf16_rowwise, because adding it directly after quantize_fp8_per_row caused f8f8bf16_rowwise to drop the seq_len dimension. (i.e., (17, 23, 1014) -> (17, 1024)) * save shape up front + comment
This commit is contained in:
@@ -45,6 +45,8 @@ class FbgemmFp8Linear(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
num_tokens = None
|
num_tokens = None
|
||||||
|
# quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
|
||||||
|
output_shape = (*x.shape[:-1], -1)
|
||||||
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
|
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
|
||||||
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
|
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
|
||||||
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||||
@@ -60,6 +62,7 @@ class FbgemmFp8Linear(torch.nn.Module):
|
|||||||
output = output + self.bias if self.bias is not None else output
|
output = output + self.bias if self.bias is not None else output
|
||||||
# Hacky for now, we have the output to the device of x
|
# Hacky for now, we have the output to the device of x
|
||||||
output = output.to(x.device)
|
output = output.to(x.device)
|
||||||
|
output = output.reshape(output_shape)
|
||||||
del x_quantized, x_scale
|
del x_quantized, x_scale
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@@ -268,3 +268,34 @@ class FbgemmFp8Test(unittest.TestCase):
|
|||||||
|
|
||||||
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_accelerate
|
||||||
|
@require_fbgemm_gpu
|
||||||
|
class FbgemmFp8LinearTest(unittest.TestCase):
|
||||||
|
def test_linear_preserves_shape(self):
|
||||||
|
"""
|
||||||
|
Test that FbgemmFp8Linear preserves shape when in_features == out_features.
|
||||||
|
"""
|
||||||
|
from transformers.integrations import FbgemmFp8Linear
|
||||||
|
|
||||||
|
with init_empty_weights(include_buffers=True):
|
||||||
|
linear = FbgemmFp8Linear(1024, 1024, True)
|
||||||
|
x = torch.rand((17, 23, 1024))
|
||||||
|
|
||||||
|
x_ = linear(x)
|
||||||
|
self.assertEqual(x_.shape, x.shape)
|
||||||
|
|
||||||
|
def test_linear_with_diff_feature_size_preserves_shape(self):
|
||||||
|
"""
|
||||||
|
Test that FbgemmFp8Linear generates the correct shape when in_features != out_features.
|
||||||
|
"""
|
||||||
|
from transformers.integrations import FbgemmFp8Linear
|
||||||
|
|
||||||
|
with init_empty_weights(include_buffers=True):
|
||||||
|
linear = FbgemmFp8Linear(1024, 2048, True)
|
||||||
|
x = torch.rand((17, 23, 1024))
|
||||||
|
|
||||||
|
x_ = linear(x)
|
||||||
|
self.assertEqual(x_.shape, (17, 23, 2048))
|
||||||
|
|||||||
Reference in New Issue
Block a user