Fix: FalconMamba training issues due to incompatible kernels (#33195)
* fix FM training kernels
* fix copies
* fix copies
* propagate to slow path
* make it BC
* add comment
* fix test
This commit is contained in:
@@ -524,3 +524,32 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(out, EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
def test_training_kernel(self):
    """Compare a no-grad forward pass with a train-mode one and run backward.

    The same prompt is fed through the model twice: once under
    ``torch.no_grad()`` and once after ``model.train()``. The argmax token
    ids from both passes are decoded and must match. A backward pass is
    also run on a loss built from the train-mode logits, to confirm that
    it completes without raising.
    """
    checkpoint = "tiiuae/falcon-mamba-7b"

    tok = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
    tok.pad_token_id = tok.eos_token_id

    prompt = "Hello today"
    encoded = tok(prompt, return_tensors="pt").to(torch_device)

    # Reference pass with gradient tracking disabled.
    with torch.no_grad():
        reference_ids = model(**encoded).logits.argmax(dim=-1)
    out_no_training = tok.batch_decode(reference_ids)

    # Same inputs again, after switching the module to training mode.
    model.train()
    lm_logits = model(**encoded).logits
    out_training = tok.batch_decode(lm_logits.argmax(dim=-1))

    # Just verify backward works
    (1 - lm_logits).mean().backward()

    self.assertEqual(out_training, out_no_training)
|
||||
|
||||
Reference in New Issue
Block a user