Fix: Fix FalconMamba training issues due to incompatible kernels (#33195)

* fix FM training kernels

* fix copies

* fix copies

* propagate to slow path

* make it BC

* add comment

* fix test
This commit is contained in:
Younes Belkada
2024-09-05 13:55:08 +04:00
committed by GitHub
parent 43df47d8e7
commit 47b096412d
6 changed files with 597 additions and 11 deletions

View File

@@ -524,3 +524,32 @@ class FalconMambaIntegrationTests(unittest.TestCase):
out = tok.batch_decode(out, skip_special_tokens=True)
self.assertListEqual(out, EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_training_kernel(self):
    # Integration check: run the same prompt through the model once under
    # torch.no_grad() and once after model.train(), call backward on a scalar
    # loss built from the train-mode logits, and assert that the decoded
    # argmax tokens from both passes are identical.
    checkpoint = "tiiuae/falcon-mamba-7b"
    prompt = "Hello today"

    tok = AutoTokenizer.from_pretrained(checkpoint)
    tok.pad_token_id = tok.eos_token_id
    net = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)

    encoded = tok(prompt, return_tensors="pt")
    batch = encoded.to(torch_device)

    # Reference pass: no autograd graph is recorded here.
    with torch.no_grad():
        reference_ids = torch.argmax(net(**batch).logits, dim=-1)
    decoded_reference = tok.batch_decode(reference_ids)

    # Second pass after switching the module to training mode; gradients are
    # enabled so the logits can be back-propagated through below.
    net.train()
    train_logits = net(**batch).logits
    train_ids = torch.argmax(train_logits, dim=-1)
    decoded_train = tok.batch_decode(train_ids)

    # The loss value itself is irrelevant; it only has to be a scalar so that
    # backward() can run end to end without raising.
    dummy_loss = (1 - train_logits).mean()
    dummy_loss.backward()

    self.assertEqual(decoded_train, decoded_reference)