[Mamba doc] Post merge updates (#29472)

* post merge update

* nit

* oops
This commit is contained in:
Arthur
2024-03-11 19:46:24 +11:00
committed by GitHub
parent 0290ec19c9
commit 4f27ee936a
3 changed files with 14 additions and 17 deletions

View File

@@ -406,15 +406,15 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_torch
class MambaIntegrationTests(unittest.TestCase):
def setUp(self):
self.model_id = "ArthurZ/mamba-2.8b"
self.model_id = "state-spaces/mamba-2.8b-hf"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
@parameterized.expand([(torch_device,), ("cpu",)])
def test_simple_generate(self, device):
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
tokenizer.pad_token = tokenizer.eos_token
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16)
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16)
model.to(device)
model.config.use_cache = True
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
@@ -444,7 +444,7 @@ class MambaIntegrationTests(unittest.TestCase):
expected_output = "Hello my name is John and I am a newbie to the world"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(device)
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=10)
output_sentence = self.tokenizer.decode(output[0].tolist())
@@ -457,7 +457,7 @@ class MambaIntegrationTests(unittest.TestCase):
expected_output = "Hello my name is\n\nI am a\n\nI am a"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(device)
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-790m-hf", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=10)
output_sentence = self.tokenizer.decode(output[0].tolist())
@@ -470,7 +470,7 @@ class MambaIntegrationTests(unittest.TestCase):
expected_output = "Hello my name is John and I am a\n\nI am a single father of a beautiful daughter. I am a"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(device)
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
@@ -483,7 +483,7 @@ class MambaIntegrationTests(unittest.TestCase):
expected_output = "Hello my name is John and I am a new member of this forum. I am a retired Marine and I am a member of the Marine Corps League. I am a"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(device)
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=30)
output_sentence = self.tokenizer.decode(output[0].tolist())