feat: run benchmarks on A100 (#34287)

This commit is contained in:
Luc Georges
2024-10-28 19:33:17 +01:00
committed by GitHub
parent d21dbd1520
commit 6cc4a67b3d
3 changed files with 943 additions and 779 deletions

View File

@@ -16,8 +16,11 @@ env:
jobs:
benchmark:
name: Benchmark
strategy:
matrix:
group: [aws-g5-4xlarge-cache, aws-p4d-24xlarge-plus]
runs-on:
group: aws-g5-4xlarge-cache
group: ${{ matrix.group }}
if: |
(github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-benchmark') )||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
@@ -60,9 +63,13 @@ jobs:
commit_id=$GITHUB_SHA
fi
commit_msg=$(git show -s --format=%s | cut -c1-70)
df -h
python3 benchmark/llama.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg"
env:
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
# Uncomment the two variables below to see debug logs
# HF_HUB_VERBOSITY: debug
# TRANSFORMERS_VERBOSITY: debug
PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }}
PGUSER: transformers_benchmarks
PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }}

View File

@@ -39,7 +39,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -77,7 +77,7 @@
"properties": [
{
"id": "custom.width",
"value": 364
"value": 196
}
]
},
@@ -101,7 +101,7 @@
"properties": [
{
"id": "custom.width",
"value": 708
"value": 581
}
]
},
@@ -113,7 +113,7 @@
"properties": [
{
"id": "custom.width",
"value": 388
"value": 379
}
]
}
@@ -148,7 +148,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT commit_id as commit_id, commit_message, gpu_name FROM benchmarks WHERE branch = '${branch}';",
"rawSql": "SELECT commit_id as commit_id, commit_message, gpu_name, created_at AS date FROM benchmarks WHERE branch = '${branch}' ORDER BY benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -232,7 +232,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -312,7 +312,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'first_eager_forward_pass_time_secs' AS double precision) AS first_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}'",
"rawSql": "SELECT CAST(m.measurements->'first_eager_forward_pass_time_secs' AS double precision) AS first_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -334,6 +334,19 @@
}
],
"title": "First eager forward pass",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -341,7 +354,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -424,7 +437,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'second_eager_forward_pass_time_secs' AS double precision) AS second_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}'",
"rawSql": "SELECT CAST(m.measurements->'second_eager_forward_pass_time_secs' AS double precision) AS second_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -446,6 +459,19 @@
}
],
"title": "Second eager forward pass",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -466,7 +492,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -545,7 +571,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'time_to_first_token_secs' AS double precision) AS time_to_first_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}'",
"rawSql": "SELECT CAST(m.measurements->'time_to_first_token_secs' AS double precision) AS time_to_first_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -567,6 +593,19 @@
}
],
"title": "Time to first token",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -574,7 +613,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -653,7 +692,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'time_to_second_token_secs' AS double precision) AS time_to_second_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}'",
"rawSql": "SELECT CAST(m.measurements->'time_to_second_token_secs' AS double precision) AS time_to_second_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -675,6 +714,19 @@
}
],
"title": "Time to second token",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -682,7 +734,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -761,7 +813,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'time_to_third_token_secs' AS double precision) AS time_to_third_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}'",
"rawSql": "SELECT CAST(m.measurements->'time_to_third_token_secs' AS double precision) AS time_to_third_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -783,6 +835,19 @@
}
],
"title": "Time to third token",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -790,7 +855,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -869,7 +934,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'time_to_next_token_mean_secs' AS double precision) AS time_to_next_token_mean_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}'",
"rawSql": "SELECT CAST(m.measurements->'time_to_next_token_mean_secs' AS double precision) AS time_to_next_token_mean_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -891,6 +956,19 @@
}
],
"title": "Time to subsequent next tokens mean",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -911,7 +989,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -990,7 +1068,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'first_compile_generate_time_secs' AS double precision) AS first_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}'",
"rawSql": "SELECT CAST(m.measurements->'first_compile_generate_time_secs' AS double precision) AS first_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1012,6 +1090,19 @@
}
],
"title": "First compile generate",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -1019,7 +1110,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -1098,7 +1189,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'second_compile_generate_time_secs' AS double precision) AS second_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}';",
"rawSql": "SELECT CAST(m.measurements->'second_compile_generate_time_secs' AS double precision) AS second_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1120,6 +1211,19 @@
}
],
"title": "Second compile generate",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -1127,7 +1231,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -1206,7 +1310,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'third_compile_generate_time_secs' AS double precision) AS third_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}';",
"rawSql": "SELECT CAST(m.measurements->'third_compile_generate_time_secs' AS double precision) AS third_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1228,6 +1332,19 @@
}
],
"title": "Third compile generate",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
@@ -1235,7 +1352,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -1314,7 +1431,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT CAST(m.measurements->'fourth_compile_generate_time_secs' AS double precision) AS fourth_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}';",
"rawSql": "SELECT CAST(m.measurements->'fourth_compile_generate_time_secs' AS double precision) AS fourth_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1336,11 +1453,24 @@
}
],
"title": "Fourth compile generate",
"transformations": [
{
"id": "sortBy",
"options": {
"fields": {},
"sort": [
{
"field": "time"
}
]
}
}
],
"transparent": true,
"type": "barchart"
},
{
"collapsed": false,
"collapsed": true,
"gridPos": {
"h": 1,
"w": 24,
@@ -1348,15 +1478,12 @@
"y": 64
},
"id": 15,
"panels": [],
"title": "Usage metrics",
"type": "row"
},
"panels": [
{
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -1442,7 +1569,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT\n d.cpu_util,\n d.time\nFROM\n benchmarks AS b\n JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id\nWHERE\n branch = '${branch}'",
"rawSql": "SELECT\n d.cpu_util,\n d.time\nFROM\n benchmarks AS b\n JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id\nWHERE\n branch = '${branch}';",
"refId": "A",
"sql": {
"columns": [
@@ -1541,7 +1668,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -1627,7 +1754,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT\n b.commit_id,\n d.gpu_util,\n d.time\nFROM\n benchmarks AS b\n JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id\nWHERE\n branch = '${branch}'",
"rawSql": "SELECT\n b.commit_id,\n d.gpu_util,\n d.time\nFROM\n benchmarks AS b\n JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id\nWHERE\n branch = '${branch}';",
"refId": "A",
"sql": {
"columns": [
@@ -1726,7 +1853,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -1812,7 +1939,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT d.mem_megabytes, d.time FROM benchmarks AS b JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id WHERE branch = '${branch}'",
"rawSql": "SELECT d.mem_megabytes, d.time FROM benchmarks AS b JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id WHERE branch = '${branch}';",
"refId": "A",
"sql": {
"columns": [
@@ -1911,7 +2038,7 @@
"datasource": {
"default": true,
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"fieldConfig": {
"defaults": {
@@ -1997,7 +2124,7 @@
"editorMode": "code",
"format": "table",
"rawQuery": true,
"rawSql": "SELECT\n d.gpu_mem_megabytes,\n d.time\nFROM\n benchmarks AS b\n JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id\nWHERE\n branch = '${branch}'",
"rawSql": "SELECT\n d.gpu_mem_megabytes,\n d.time\nFROM\n benchmarks AS b\n JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id\nWHERE\n branch = '${branch}';",
"refId": "A",
"sql": {
"columns": [
@@ -2093,6 +2220,11 @@
"type": "timeseries"
}
],
"title": "Usage metrics",
"type": "row"
}
],
"refresh": "",
"schemaVersion": 39,
"tags": [],
"templating": {
@@ -2105,7 +2237,7 @@
},
"datasource": {
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"definition": "SELECT DISTINCT branch FROM benchmarks;",
"description": "",
@@ -2125,12 +2257,12 @@
{
"current": {
"selected": false,
"text": "1728662868776",
"value": "1728662868776"
"text": "1729701492845",
"value": "1729701492845"
},
"datasource": {
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"definition": "SELECT created_at - INTERVAL '5 secs' FROM benchmarks WHERE branch = '${branch}' ORDER BY benchmark_id ASC LIMIT 1;",
"description": "",
@@ -2149,12 +2281,12 @@
{
"current": {
"selected": false,
"text": "1728663254125",
"value": "1728663254125"
"text": "1730120430069",
"value": "1730120430069"
},
"datasource": {
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"definition": "SELECT time + INTERVAL '5 secs' FROM benchmarks AS b JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id WHERE branch = '${branch}' ORDER BY b.benchmark_id DESC, d.measurement_id DESC LIMIT 1;",
"description": "",
@@ -2164,7 +2296,7 @@
"name": "EndTime",
"options": [],
"query": "SELECT time + INTERVAL '5 secs' FROM benchmarks AS b JOIN device_measurements AS d ON b.benchmark_id = d.benchmark_id WHERE branch = '${branch}' ORDER BY b.benchmark_id DESC, d.measurement_id DESC LIMIT 1;",
"refresh": 2,
"refresh": 1,
"regex": "",
"skipUrlSync": false,
"sort": 0,
@@ -2178,7 +2310,7 @@
},
"datasource": {
"type": "grafana-postgresql-datasource",
"uid": "de0dbhs18ho1sc"
"uid": "be28nkzirtb0gd"
},
"definition": "SELECT DISTINCT gpu_name FROM benchmarks;",
"hide": 0,
@@ -2188,11 +2320,32 @@
"name": "gpu_name",
"options": [],
"query": "SELECT DISTINCT gpu_name FROM benchmarks;",
"refresh": 1,
"refresh": 2,
"regex": "",
"skipUrlSync": false,
"sort": 0,
"type": "query"
},
{
"current": {
"selected": false,
"text": "10",
"value": "10"
},
"description": "The number of commits to display, going from most recent to the nth commit.",
"hide": 0,
"label": "Last # of commits",
"name": "last_n_commits",
"options": [
{
"selected": true,
"text": "10",
"value": "10"
}
],
"query": "10",
"skipUrlSync": false,
"type": "textbox"
}
]
},
@@ -2206,6 +2359,6 @@
"timezone": "browser",
"title": "Transformers benchmarks",
"uid": "fdz33iyzln9c0a",
"version": 11,
"version": 4,
"weekStart": ""
}

View File

@@ -96,17 +96,21 @@ def run_benchmark(branch: str, commit_id: str, commit_msg: str, num_tokens_to_ge
)
conn.commit()
benchmark_id = cur.fetchone()[0]
logger.info(f"running benchmark #{benchmark_id} on {gpu_name}")
metrics_thread = Thread(target=collect_metrics, args=[benchmark_id, continue_metric_collection])
metrics_thread.start()
logger.info("started background thread to fetch device metrics")
os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling
device = "cuda"
ckpt = "meta-llama/Llama-2-7b-hf"
logger.info("downloading weights")
# Fetch the weights with a first from_pretrained call so that the download is not counted in the model load time measured below
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
logger.info("loading model")
start = perf_counter()
model = AutoModelForCausalLM.from_pretrained(
ckpt, torch_dtype=torch.float16, generation_config=gen_config