Update expected values (after switching to A10) - part 5 (#39205)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ import pytest
|
||||
|
||||
from transformers import Starcoder2Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
@@ -148,10 +149,20 @@ class Starcoder2IntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_starcoder2_batched_generation_4bit(self):
|
||||
EXPECTED_TEXT = [
|
||||
'Hello my name is Younes and I am a student at the University of Maryland. I am currently working on a project that is related to the topic of "How to make a game". I am currently working on a project',
|
||||
'def hello_world():\n\treturn "Hello World"\n\n@app.route(\'/hello/<name>\')\ndef hello_name(name):\n\treturn "Hello " + name\n\n@app.route',
|
||||
]
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): [
|
||||
'Hello my name is Younes and I am a student at the University of Maryland. I am currently working on a project that is related to the topic of "How to make a game". I am currently working on a project',
|
||||
'def hello_world():\n\treturn "Hello World"\n\n@app.route(\'/hello/<name>\')\ndef hello_name(name):\n\treturn "Hello " + name\n\n@app.route',
|
||||
],
|
||||
("cuda", 8): [
|
||||
"Hello my name is Younes and I am a student at the University of Maryland. I am currently working on a project that is aimed at creating a new way of learning. I am hoping to create a new way of",
|
||||
'def hello_world():\n\treturn "Hello World"\n\n@app.route(\'/hello/<name>\')\ndef hello_name(name):\n\treturn "Hello " + name\n\n@app.route',
|
||||
],
|
||||
}
|
||||
)
|
||||
EXPECTED_TEXT = expectations.get_expectation()
|
||||
|
||||
model_id = "bigcode/starcoder2-7b"
|
||||
|
||||
model = Starcoder2ForCausalLM.from_pretrained(model_id, load_in_4bit=True)
|
||||
|
||||
Reference in New Issue
Block a user