Fix FbgemmFp8Linear not preserving tensor shape (#33239)

* add tests for linear shape behavior

* fix linear shape behavior

ended up adding the reshape at the end, after f8f8bf16_rowwise, because adding
it directly after quantize_fp8_per_row caused f8f8bf16_rowwise to drop the
seq_len dimension. (i.e., (17, 23, 1014) -> (17, 1024))

* save shape up front + comment
This commit is contained in:
Theia Vogel
2024-09-11 04:26:44 -07:00
committed by GitHub
parent 781bbc4d98
commit e719b65c31
2 changed files with 34 additions and 0 deletions

View File

@@ -268,3 +268,34 @@ class FbgemmFp8Test(unittest.TestCase):
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_gpu
@require_accelerate
@require_fbgemm_gpu
class FbgemmFp8LinearTest(unittest.TestCase):
def test_linear_preserves_shape(self):
"""
Test that FbgemmFp8Linear preserves shape when in_features == out_features.
"""
from transformers.integrations import FbgemmFp8Linear
with init_empty_weights(include_buffers=True):
linear = FbgemmFp8Linear(1024, 1024, True)
x = torch.rand((17, 23, 1024))
x_ = linear(x)
self.assertEqual(x_.shape, x.shape)
def test_linear_with_diff_feature_size_preserves_shape(self):
"""
Test that FbgemmFp8Linear generates the correct shape when in_features != out_features.
"""
from transformers.integrations import FbgemmFp8Linear
with init_empty_weights(include_buffers=True):
linear = FbgemmFp8Linear(1024, 2048, True)
x = torch.rand((17, 23, 1024))
x_ = linear(x)
self.assertEqual(x_.shape, (17, 23, 2048))