@@ -57,7 +57,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa")
|
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa")
|
||||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
|
||||||
input_ids = tokenzier("Hello, I'm a language model". return_tensors="pt").to("cuda")
|
input_ids = tokenizer("Hello, I'm a language model", return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
output = model.generate(**input_ids, cache_implementation="static")
|
output = model.generate(**input_ids, cache_implementation="static")
|
||||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||||
|
|||||||
Reference in New Issue
Block a user