* stash for now

* initial commit

* small updated

* up

* up

* works!

* nits and fixes

* don't loop too much

* finish working example

* update

* fix the small freeblocks issue

* feat: stream inputs to continuous batch

* fix: update attn from `eager` to `sdpa`

* refactor: fmt

* refactor: cleanup unnecessary code

* feat: add `update` fn to `PagedAttentionCache`

* feat: broken optimal block size computation

* fix: debugging invalid cache logic

* fix: attention mask

* refactor: use custom prompts for example

* feat: add streaming output

* fix: prefill split

refactor: add doc strings and unsound/redundant logic
fix: compute optimal blocks logic

* fix: send decoded tokens when `prefilling_split` -> `decoding`

* refactor: move logic to appropriate parent class

* fix: remove truncation as we split prefilling anyways

refactor: early return when we have enough selected requests

* feat: add paged attention forward

* push Ggraoh>

* add paged sdpa

* update

* better mps defaults

* feat: add progress bar for `generate_batch`

* feat: add opentelemetry metrics (ttft + batch fill %age)

* feat: add tracing

* Add cuda graphs (#38059)

* draft cudagraphs addition

* nits

* styling

* update

* fix

* kinda draft of what it should look like

* fixes

* lol

* not sure why inf everywhere

* can generate but output is shit

* some fixes

* we should have a single device synch

* broken outputs but it does run

* refactor

* updates

* updates with some fixes

* fix mask causality

* another commit that casts after

* add error

* simplify example

* update

* updates

* revert llama changes

* fix merge conflicts

* fix: tracing and metrics

* my updates

* update script default values

* fix block allocation issue

* fix prefill split attention mask

* no bugs

* add paged eager

* fix

* update

* style

* feat: add pytorch traces

* fix

* fix

* refactor: remove pytorch profiler data

* style

* nits

* cleanup

* draft test file

* fix

* fix

* fix paged and graphs

* small renamings

* cleanups and push

* refactor: move tracing and metrics logic to utils

* refactor: trace more blocks of code

* nits

* nits

* update

* to profile or not to profile

* refactor: create new output object

* causal by default

* cleanup but generations are still off for IDK what reason

* simplifications but not running still

* this does work.

* small quality of life updates

* nits

* update

* fix the scheduler

* fix warning

* ol

* fully fixed

* nits

* different generation parameters

* nice

* just style

* feat: add cache memory usage

* feat: add kv cache free memory

* feat: add active/waiting count & req latency

* do the sampling

* fix: synchronize CUDA only if available and improve error handling in ContinuousBatchingManager

* fix on mps

* feat: add dashboard & histogram buckets

* perf: improve waiting reqs data structures

* attempt to compile, but we should only do it on mps AFAIK

* feat: decouple scheduling logic

* just a draft

* cleanup and fixup

* optional

* style

* update

* update

* remove the draft documentation

* fix import as well

* update

* fix the test

* style doomed

---------

Co-authored-by: Luc Georges <luc.sydney.georges@gmail.com>
This commit is contained in:
Arthur
2025-05-22 17:43:48 +02:00
committed by GitHub
parent 73286d8e29
commit 211f2b0875
21 changed files with 3467 additions and 11 deletions

View File

@@ -0,0 +1,4 @@
# Metrics Monitoring
## Continuous Batching Metrics in Transformers

View File

@@ -0,0 +1,974 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"target": {
"limit": 100,
"matchAny": false,
"tags": [],
"type": "dashboard"
},
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"id": 2,
"links": [],
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"description": "Memory usage of the PagedAttentionCache",
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"max": 10737418240,
"min": 0,
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green"
},
{
"color": "yellow",
"value": 5368709120
},
{
"color": "red",
"value": 8589934592
}
]
},
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 0,
"y": 0
},
"id": 2,
"options": {
"minVizHeight": 75,
"minVizWidth": 75,
"orientation": "auto",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showThresholdLabels": false,
"showThresholdMarkers": true,
"sizing": "auto"
},
"pluginVersion": "12.0.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "kv_cache_memory_bytes",
"fullMetaSearch": false,
"includeNullMetadata": true,
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
}
],
"title": "KV Cache Memory Usage",
"transparent": true,
"type": "gauge"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "dark-blue"
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 6,
"y": 0
},
"id": 13,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "12.0.0",
"targets": [
{
"disableTextWrap": false,
"editorMode": "builder",
"expr": "active_requests_count",
"fullMetaSearch": false,
"includeNullMetadata": true,
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
}
],
"title": "Active Requests",
"transparent": true,
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "dark-orange"
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 12,
"y": 0
},
"id": 14,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "12.0.0",
"targets": [
{
"disableTextWrap": false,
"editorMode": "builder",
"expr": "waiting_requests_count",
"fullMetaSearch": false,
"includeNullMetadata": true,
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
}
],
"title": "Waiting Requests",
"transparent": true,
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"description": "Ratio of decode tokens to prefill tokens in a batch",
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "blue"
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 18,
"y": 0
},
"id": 6,
"options": {
"colorMode": "value",
"graphMode": "none",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "12.0.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "decode_prefill_ratio",
"fullMetaSearch": false,
"includeNullMetadata": true,
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
}
],
"title": "Decode/Prefill Ratio",
"transparent": true,
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green"
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 8
},
"id": 10,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.0.0",
"targets": [
{
"editorMode": "code",
"expr": "rate(decode_tokens_processed_total[$__rate_interval])",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Decode tokens throupught tok/s",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green"
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 8
},
"id": 11,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.0.0",
"targets": [
{
"editorMode": "code",
"expr": "rate(prefill_tokens_processed_total[$__rate_interval])",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Prefill rate tok/s",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green"
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"id": 9,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.0.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.95, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))",
"legendFormat": "p95",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))",
"hide": false,
"instant": false,
"legendFormat": "p99",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))",
"hide": false,
"instant": false,
"legendFormat": "p50",
"range": true,
"refId": "C"
}
],
"title": "Batch fill percentage percentiles",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"description": "KV Cache Memory Usage Over Time",
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green"
},
{
"color": "red",
"value": 80
}
]
},
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"id": 4,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.0.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "kv_cache_memory_bytes",
"fullMetaSearch": false,
"includeNullMetadata": true,
"legendFormat": "Used memory",
"range": true,
"refId": "A",
"useBackend": false
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "kv_cache_free_memory_bytes",
"fullMetaSearch": false,
"hide": false,
"includeNullMetadata": true,
"instant": false,
"legendFormat": "free memory",
"range": true,
"refId": "B",
"useBackend": false
}
],
"title": "KV Cache Memory Usage Over Time",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green"
},
{
"color": "red",
"value": 80
}
]
},
"unit": "ms"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 24
},
"id": 8,
"options": {
"displayMode": "gradient",
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"maxVizHeight": 300,
"minVizHeight": 10,
"minVizWidth": 0,
"namePlacement": "auto",
"orientation": "auto",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showUnfilled": true,
"sizing": "auto",
"valueMode": "color"
},
"pluginVersion": "12.0.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "histogram_quantile(0.95, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))",
"fullMetaSearch": false,
"includeNullMetadata": true,
"legendFormat": "p95",
"range": true,
"refId": "A",
"useBackend": false
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "histogram_quantile(0.5, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))",
"fullMetaSearch": false,
"hide": false,
"includeNullMetadata": true,
"legendFormat": "p50",
"range": true,
"refId": "B",
"useBackend": false
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "histogram_quantile(0.99, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))",
"fullMetaSearch": false,
"hide": false,
"includeNullMetadata": false,
"instant": false,
"legendFormat": "p99",
"range": true,
"refId": "C",
"useBackend": false
}
],
"title": "Time to First Token (TTFT)",
"type": "bargauge"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green"
},
{
"color": "red",
"value": 80
}
]
},
"unit": "ms"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 24
},
"id": 12,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.0.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"editorMode": "code",
"expr": "histogram_quantile(0.95, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))",
"hide": false,
"instant": false,
"legendFormat": "p95",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))",
"hide": false,
"instant": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Request latency percentiles",
"type": "timeseries"
}
],
"preload": false,
"refresh": "5s",
"schemaVersion": 41,
"tags": [],
"templating": {
"list": []
},
"time": {
"from": "now-15m",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "Transformers Continuous Batching Metrics",
"uid": "Lw6CTvVSz",
"version": 5
}

View File

@@ -0,0 +1,55 @@
services:
memcached:
image: memcached:1.6.29
container_name: memcached
ports:
- "11211:11211"
environment:
- MEMCACHED_MAX_MEMORY=64m # Set the maximum memory usage
- MEMCACHED_THREADS=4 # Number of threads to use
prometheus:
image: prom/prometheus:latest
command:
- "--config.file=/etc/prometheus/prometheus.yml"
- --web.enable-otlp-receiver # Enable OTLP receiver
- --web.enable-remote-write-receiver
- --enable-feature=exemplar-storage
- --enable-feature=native-histograms
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
ports:
- "9090:9090"
tempo:
image: grafana/tempo:latest
command: [ "-config.file=/etc/tempo.yaml" ]
volumes:
- ./tempo.yaml:/etc/tempo.yaml
ports:
- "14268:14268" # jaeger ingest
- "3200:3200" # tempo
- "9095:9095" # tempo grpc
- "4317:4317" # otlp grpc
- "4318:4318" # otlp http
- "9411:9411" # zipkin
depends_on:
- memcached
grafana:
image: grafana/grafana:latest
volumes:
- ./continuous-batching-dashboard.json:/etc/grafana/provisioning/dashboards/continuous-batching-dashboard.json
- ./grafana-dashboard.yaml:/etc/grafana/provisioning/dashboards/grafana-dashboard.yaml
- ./grafana-datasources.yaml:/etc/grafana/provisioning/datasources/datasources.yaml
environment:
- GF_AUTH_ANONYMOUS_ENABLED=true
- GF_AUTH_ANONYMOUS_ORG_ROLE=Admin
- GF_AUTH_DISABLE_LOGIN_FORM=true
- GF_FEATURE_TOGGLES_ENABLE=traceqlEditor metricsSummary
- GF_INSTALL_PLUGINS=https://storage.googleapis.com/integration-artifacts/grafana-exploretraces-app/grafana-exploretraces-app-latest.zip;grafana-traces-app
ports:
- "3000:3000"
depends_on:
- prometheus
- tempo

View File

@@ -0,0 +1,11 @@
apiVersion: 1
providers:
- name: 'Transformers Dashboards'
orgId: 1
folder: 'Transformers'
type: file
disableDeletion: false
editable: true
options:
path: /etc/grafana/provisioning/dashboards

View File

@@ -0,0 +1,14 @@
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
access: proxy
url: http://prometheus:9090
isDefault: true
- name: Tempo
type: tempo
access: proxy
url: http://tempo:3200
uid: tempo

View File

@@ -0,0 +1,48 @@
# Example usage of the traced and attach_tracer decorators
from transformers.utils.metrics import attach_tracer, traced
@attach_tracer()
class ExampleClass:
    """Small demo class showing the different ways `traced` can decorate methods."""

    def __init__(self, name):
        # `self.tracer` is not set here: the attach_tracer decorator has already created it for us.
        self.name = name

    @traced  # Bare form: the method uses the tracer from the class instance
    def process_data(self, data):
        """Traced method (can use `self.tracer`); returns a message built from `data` and `self.name`."""
        message = f"Processed {data} with {self.name}"
        return message

    @traced(span_name="custom_operation")  # Same decorator, with a custom span name
    def special_operation(self, value):
        """Traced method with a custom span name; returns `value * 2`."""
        doubled = value * 2
        return doubled

    @traced(
        additional_attributes=[
            ("name", "object.name", lambda x: x.upper()),  # Using a transform function
            ("name", "object.fixed_value", "static_value"),  # Using a fixed value
        ]
    )
    def operation_with_attributes(self):
        """Traced method whose span gets the attributes listed in the decorator call."""
        outcome = "Operation completed"
        return outcome
# For functions without a class, the traced decorator still works
@traced
def standalone_function(arg1, arg2):
    """Return `arg1 + arg2`.

    Shows `traced` on a plain function: with no class instance involved, a tracer is created based on the
    module name.
    """
    total = arg1 + arg2
    return total
# Usage:
if __name__ == "__main__":
    # With OpenTelemetry configured, these will produce traces
    # Build one instance and call each decorated method once, in order.
    example = ExampleClass("test_object")
    example.process_data("sample")
    example.special_operation(42)
    example.operation_with_attributes()
    # Finally call the module-level traced function; `result` is not used afterwards in this example.
    result = standalone_function(1, 2)

View File

@@ -0,0 +1,3 @@
global:
scrape_interval: 15s

View File

@@ -0,0 +1,90 @@
stream_over_http_enabled: true
server:
http_listen_port: 3200
log_level: info
cache:
background:
writeback_goroutines: 5
caches:
- roles:
- frontend-search
memcached:
addresses: dns+memcached:11211
query_frontend:
search:
duration_slo: 5s
throughput_bytes_slo: 1.073741824e+09
metadata_slo:
duration_slo: 5s
throughput_bytes_slo: 1.073741824e+09
trace_by_id:
duration_slo: 100ms
metrics:
max_duration: 200h # maximum duration of a metrics query, increase for local setups
query_backend_after: 5m
duration_slo: 5s
throughput_bytes_slo: 1.073741824e+09
distributor:
receivers: # this configuration will listen on all ports and protocols that tempo is capable of.
jaeger: # the receivers all come from the OpenTelemetry collector. more configuration information can
protocols: # be found there: https://github.com/open-telemetry/opentelemetry-collector/tree/main/receiver
thrift_http: #
endpoint: "tempo:14268" # for a production deployment you should only enable the receivers you need!
grpc:
endpoint: "tempo:14250"
thrift_binary:
endpoint: "tempo:6832"
thrift_compact:
endpoint: "tempo:6831"
zipkin:
endpoint: "tempo:9411"
otlp:
protocols:
grpc:
endpoint: "tempo:4317"
http:
endpoint: "tempo:4318"
opencensus:
endpoint: "tempo:55678"
ingester:
max_block_duration: 5m # cut the headblock when this much time passes. this is being set for demo purposes and should probably be left alone normally
compactor:
compaction:
block_retention: 720h # overall Tempo trace retention. set for demo purposes
metrics_generator:
registry:
external_labels:
source: tempo
cluster: docker-compose
storage:
path: /var/tempo/generator/wal
remote_write:
- url: http://prometheus:9090/api/v1/write
send_exemplars: true
traces_storage:
path: /var/tempo/generator/traces
processor:
local_blocks:
filter_server_spans: false
flush_to_storage: true
storage:
trace:
backend: local # backend configuration to use
wal:
path: /var/tempo/wal # where to store the wal locally
local:
path: /var/tempo/blocks
overrides:
defaults:
metrics_generator:
processors: [service-graphs, span-metrics, local-blocks] # enables metrics generator
generate_native_histograms: both

View File

@@ -0,0 +1,109 @@
import time
import datasets
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
# Select the "high" float32 matmul precision mode (see `torch.set_float32_matmul_precision`).
torch.set_float32_matmul_precision("high")
model_id = "meta-llama/Llama-3.2-3b-Instruct"
# Load the checkpoint in bfloat16 with the "sdpa_paged" attention implementation, let `device_map="auto"`
# place it, and switch the module to eval mode.
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
# Generation settings handed to `model.generate_batch` further down in this script.
# NOTE(review): `num_blocks` / `block_size` presumably size the paged KV cache and `scheduler` presumably picks
# the request scheduling policy — confirm against the continuous batching implementation.
generation_config = GenerationConfig(
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    use_cache=False,
    num_blocks=2048,
    block_size=128,
    do_sample=True,
    max_batch_tokens=1024,  # Maximum number of tokens to process in a single batch
    scheduler="prefill_first",
)
# NOTE: despite the variable name, this loads the "test" split of openai/gsm8k ("socratic" configuration).
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
# --- Example 1: Simple Version using generate_batch ---
print("--- Running CB Generation Example ---")
def tokenize_function(examples):
    """Run the module-level `tokenizer` on the "question" field of `examples` and return its output."""
    questions = examples["question"]
    return tokenizer(questions)
# Apply `tokenize_function` over the dataset (`batched=True` is forwarded to `Dataset.map`).
tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
# Collect the "input_ids" entry of every tokenized row; this list is the input to `generate_batch`.
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
# Wall-clock timer around the generate_batch call only (tokenization above is not included).
start_time_simple = time.time()
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True)
batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,
    generation_config=generation_config,
)
end_time_simple = time.time()
# Print every request's prompt and generated text.
# `batch_outputs` is iterated by key and indexed with that key; each entry exposes `prompt_ids` and
# `generated_tokens`.
for request in batch_outputs:
    input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
    try:
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
    except Exception as e:
        # Best-effort fallback: report the failure, then retry the decode without the first generated token.
        print(f"Decoding failed for request {request}: {e}")
        output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
    if len(output_text) > 0:
        print("-" * 20)
        print(f"{request} Input: {input_text}")
        print(f"{request} Output: {output_text}")
    else:
        # Empty generation: emit only carriage returns, no visible text and no newline.
        print("", end="\r\r\r\r")
# NOTE(review): indentation was lost in this view; the closing separator and footer below are assumed to run
# once after the loop — confirm against the original file.
print("-" * 20)
print("--- Finished CB Generation Example ---\n\n")
# Elapsed time between the two `time.time()` calls taken around `generate_batch`.
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512)
# simple_batch_inputs = list(tokenized_test_prompts["input_ids"])
# def tokenize_function(examples):
# # Truncate to avoid overly long prompts exceeding max context length
# return tokenizer(examples["question"], padding=True, truncation=True, max_length=512)
# tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
# model.config.attn_implementation = "sdpa"
# start_time_simple = time.time()
# batch_size = 64
# full_outputs = []
# from tqdm import tqdm
# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)):
# outputs = model.generate(
# torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device),
# generation_config=GenerationConfig(
# max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
# ),
# )
# full_outputs.extend(outputs.tolist())
# end_time_simple = time.time()
# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds")
# print("\nResults from simple generate_batch:")
# for i, request in enumerate(full_outputs):
# output_text = tokenizer.decode(request, skip_special_tokens=False)
# print("-" * 20)
# print(f" Output: {output_text}")
# print("-" * 20)
# print("--- Finished Simple Batch Generation Example ---\n\n")

View File

@@ -201,6 +201,9 @@ _deps = [
"pytest-rich",
"libcst",
"rich",
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-sdk",
]
@@ -435,6 +438,9 @@ extras["torchhub"] = deps_list(
extras["benchmark"] = deps_list("optimum-benchmark")
# OpenTelemetry dependencies for metrics collection in continuous batching
extras["open-telemetry"] = deps_list("opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk")
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
install_requires = [
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads

View File

@@ -103,4 +103,7 @@ deps = {
"pytest-rich": "pytest-rich",
"libcst": "libcst",
"rich": "rich",
"opentelemetry-api": "opentelemetry-api",
"opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp",
"opentelemetry-sdk": "opentelemetry-sdk",
}

View File

@@ -97,6 +97,9 @@ else:
"validate_stopping_criteria",
"StopStringCriteria",
]
_import_structure["continuous_batching"] = [
"ContinuousMixin",
]
_import_structure["utils"] = [
"GenerationMixin",
"GreedySearchEncoderDecoderOutput",
@@ -213,6 +216,7 @@ if TYPE_CHECKING:
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
)
from .continuous_batching import ContinuousMixin
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,

File diff suppressed because it is too large Load Diff

View File

@@ -79,6 +79,7 @@ from .configuration_utils import (
GenerationConfig,
GenerationMode,
)
from .continuous_batching import ContinuousMixin
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
@@ -352,7 +353,7 @@ GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDec
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
class GenerationMixin:
class GenerationMixin(ContinuousMixin):
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
Inheriting from this class causes the model to have special generation-related behavior, such as loading a
@@ -1099,10 +1100,10 @@ class GenerationMixin:
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: int,
encoder_input_ids: torch.LongTensor,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
logits_processor: Optional[LogitsProcessorList],
input_ids_seq_length: Optional[int] = None,
encoder_input_ids: torch.LongTensor = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
device: Optional[str] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
@@ -1114,6 +1115,8 @@ class GenerationMixin:
"""
# instantiate processors list
processors = LogitsProcessorList()
if logits_processor is None:
logits_processor = []
if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
processors.append(
@@ -1183,7 +1186,7 @@ class GenerationMixin:
)
if (
generation_config.min_length is not None
and generation_config._eos_token_tensor is not None
and getattr(generation_config, "_eos_token_tensor", None) is not None
and generation_config.min_length > 0
):
processors.append(
@@ -1195,7 +1198,7 @@ class GenerationMixin:
)
if (
generation_config.min_new_tokens is not None
and generation_config._eos_token_tensor is not None
and getattr(generation_config, "_eos_token_tensor", None) is not None
and generation_config.min_new_tokens > 0
):
processors.append(

View File

@@ -0,0 +1,45 @@
from typing import Optional
import torch
from torch import nn
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head dimension.

    Gives the same result as `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep` is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
def eager_paged_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Eager (matmul + softmax) attention that can read/write a paged key-value cache.

    Args:
        module: The attention layer. `num_key_value_groups` and `training` are read from it, and `layer_idx`
            as well when a cache is supplied.
        query: Query states, 4-D with the sequence axis at dim 2.
        key: Key states for the current step, 4-D (batch, num_key_value_heads, seqlen, head_dim).
        value: Value states, same layout as `key`.
        attention_mask: Optional mask added to the raw attention scores. Its last dimension is cropped to the
            key length before the addition.
        scaling: Factor applied to `query @ key^T` before the softmax.
        dropout: Dropout probability applied to the attention probabilities. It only has an effect while
            `module.training` is True.
        **kwargs: `cache` is popped from here; when it is not None, the remaining kwargs are forwarded to
            `cache.update(key, value, module.layer_idx, **kwargs)` and its return value replaces `key`/`value`.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped back (sequence before heads)
        and `attn_weights` are the attention probabilities used to build it.
    """
    cache = kwargs.pop("cache", None)
    if cache is not None:
        key, value = cache.update(key, value, module.layer_idx, **kwargs)

    # Expand the KV heads so that they line up one-to-one with the query heads.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax is computed in float32 for numerical stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # Honour the `dropout` argument instead of silently ignoring it; this is a no-op in eval mode.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights

View File

@@ -0,0 +1,64 @@
import torch
from ..generation.continuous_batching import PagedAttentionCache
from ..utils import is_flash_attn_2_available
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
def paged_attention_forward(
    module: torch.nn.Module,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor = None,
    cache: PagedAttentionCache = None,
    cumulative_seqlens_q=None,
    cumulative_seqlens_k=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    block_tables=None,
    **kwargs,
) -> "tuple[torch.Tensor, None]":
    r"""Perform the forward pass of attention with a paged key-value cache.

    The new key/value states are first written to (and the full states read back from) `cache`, then attention
    is computed with `flash_attn_varlen_func` using causal masking and no sliding window.

    Args:
        module: The attention layer; `layer_idx` and `scaling` are read from it.
        q: Query states. Dims 1 and 2 are swapped and dim 0 is squeezed before the kernel call, so the kernel
            receives (total_q, nheads, headdim), where total_q is the total number of query tokens in the batch.
            NOTE(review): this implies an incoming layout of (1, nheads, total_q, headdim); confirm with callers.
        k: Key states for the current step; replaced by the output of `cache.update` and then reshaped like `q`.
        v: Value states, handled like `k`.
        attention_mask: Accepted for interface compatibility; not used here (masking is done by the kernel).
        cache: The paged KV cache. Required.
        cumulative_seqlens_q: (batch_size + 1,) cumulative sequence lengths used to index into q. Required;
            cast to `torch.int32` before the kernel call.
        cumulative_seqlens_k: (batch_size + 1,) cumulative sequence lengths used to index into k/v. Required;
            also forwarded to `cache.update` and cast to `torch.int32` before the kernel call.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        block_tables: Accepted but currently not forwarded to the kernel.
        **kwargs: Forwarded to `cache.update`.

    Returns:
        `(attn_output, None)`; no attention weights are returned.

    Raises:
        ValueError: If `cache`, `cumulative_seqlens_q` or `cumulative_seqlens_k` is None.
    """
    # These arguments default to None only to fit the generic attention-interface signature; fail early with
    # an explicit message instead of an AttributeError on None further down.
    required = (
        ("cache", cache),
        ("cumulative_seqlens_q", cumulative_seqlens_q),
        ("cumulative_seqlens_k", cumulative_seqlens_k),
    )
    missing = [name for name, arg in required if arg is None]
    if missing:
        raise ValueError(f"paged_attention_forward requires {', '.join(missing)} to be provided.")

    k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)

    attn_output = flash_attn_varlen_func(
        q.transpose(1, 2).squeeze(0),
        k.transpose(1, 2).squeeze(0),
        v.transpose(1, 2).squeeze(0),
        cumulative_seqlens_q.to(torch.int32),
        cumulative_seqlens_k.to(torch.int32),
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale=module.scaling,
        causal=True,  # kind of a must, it automatically aligns the mask for q < k
        window_size=(-1, -1),  # -1 means infinite context window
    )
    return attn_output, None

View File

@@ -0,0 +1,51 @@
from typing import Optional, Tuple
import torch
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times, matching `torch.repeat_interleave(x, dim=1, repeats=n_rep)`.

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input is returned untouched.
    """
    n_batch, n_kv_heads, n_tokens, d_head = hidden_states.shape
    if n_rep != 1:
        # New axis after the head axis -> broadcast it `n_rep` times -> merge it into the head axis.
        tiled = hidden_states.unsqueeze(2).expand(n_batch, n_kv_heads, n_rep, n_tokens, d_head)
        return tiled.reshape(n_batch, n_kv_heads * n_rep, n_tokens, d_head)
    return hidden_states
def sdpa_attention_paged_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
cache = kwargs.pop("cache", None)
if cache is not None:
key, value = cache.update(key, value, module.layer_idx, **kwargs)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=dropout,
scale=scaling,
is_causal=False,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None

View File

@@ -427,9 +427,9 @@ class FlashAttentionKwargs(TypedDict, total=False):
Keyword arguments for Flash Attention with Compile.
Attributes:
cu_seq_lens_q (`torch.LongTensor`, *optional*)
cumulative_seqlens_q (`torch.LongTensor`, *optional*)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`, *optional*)
cumulative_seqlens_k (`torch.LongTensor`, *optional*)
Gets cumulative sequence length for key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
@@ -437,7 +437,7 @@ class FlashAttentionKwargs(TypedDict, total=False):
Maximum sequence length for key state.
"""
cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]
cumulative_seqlens_q: Optional[torch.LongTensor]
cumulative_seqlens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]

View File

@@ -57,9 +57,12 @@ from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.eager_paged import eager_paged_attention_forward
from .integrations.flash_attention import flash_attention_forward
from .integrations.flash_paged import paged_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.sdpa_paged import sdpa_attention_paged_forward
from .integrations.tensor_parallel import (
ALL_PARALLEL_STYLES,
_get_parameter_tp_plan,
@@ -6089,7 +6092,10 @@ class AttentionInterface(GeneralInterface):
_global_mapping = {
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"paged_attention": paged_attention_forward,
"sdpa": sdpa_attention_forward,
"sdpa_paged": sdpa_attention_paged_forward,
"eager_paged": eager_paged_attention_forward,
}

View File

@@ -0,0 +1,434 @@
import functools
import logging
import time
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
class RequestStatus(Enum):
    """Status of a generation request through its lifecycle."""

    # Every member's value is its own name in lowercase.
    # NOTE(review): the per-member meanings below are inferred from the names and from how
    # `ContinuousBatchProcessorMetrics.record_batch_metrics` uses them; confirm against the scheduler.
    PENDING = "pending"  # presumably: submitted but not yet scheduled
    PREFILLING = "prefilling"  # counted as prefill tokens by record_batch_metrics
    PREFILLING_SPLIT = "prefilling_split"  # also counted as prefill tokens by record_batch_metrics
    SPLIT_PENDING_REMAINDER = "split_pending_remainder"  # presumably: rest of a split prompt still to prefill
    DECODING = "decoding"  # counted as one decode token per request by record_batch_metrics
    FINISHED = "finished"
    FAILED = "failed"
# Optional OpenTelemetry integration. When every import below succeeds, global meter and tracer
# providers are installed as a side effect of importing this module.
try:
    from opentelemetry import metrics, trace
    from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
    from opentelemetry.sdk.resources import Resource
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
    from opentelemetry.trace import Status, StatusCode, get_tracer

    # Both providers share one resource carrying the service name "transformers".
    resource = Resource.create({"service.name": "transformers"})

    # Metrics: OTLP-over-HTTP exporter wrapped in a periodic reader with a 1000 ms export interval.
    metrics_exporter = PeriodicExportingMetricReader(OTLPMetricExporter(), export_interval_millis=1000)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter])
    metrics.set_meter_provider(meter_provider)

    # Traces: OTLP-over-HTTP span exporter behind a batching span processor.
    trace_exporter = OTLPSpanExporter()
    tracer_provider = TracerProvider(resource=resource)
    tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
    trace.set_tracer_provider(tracer_provider)

    _has_opentelemetry = True
except ImportError:
    # Any missing OpenTelemetry package turns the flag off; `attach_tracer`, `traced` and the metric
    # recorders in this module check it and do nothing when it is False.
    _has_opentelemetry = False
def attach_tracer(tracer_name_template=None):
    """
    Class decorator factory that gives every instance of the decorated class a `tracer` attribute.

    The decorated class has its `__init__` wrapped: the original initializer runs first, then
    `self.tracer = get_tracer(<name>)` is assigned. The `traced` decorator can pick that tracer up.

    Args:
        tracer_name_template: Optional template string for the tracer name. It is formatted with
            `{module}` (the class's `__module__`) and `{class_name}` (the class's `__qualname__`).
            If None, the name is `"<module>.<class_name>"`, with `"transformers."` prepended unless the
            module path already starts with it.

    Returns:
        A class decorator. When OpenTelemetry is unavailable it returns the class unchanged.
    """

    def passthrough(cls):
        return cls

    if not _has_opentelemetry:
        return passthrough

    def decorator(cls):
        wrapped_init = cls.__init__

        @functools.wraps(wrapped_init)
        def init_with_tracer(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)

            module_path = cls.__module__
            qualified_name = cls.__qualname__
            if tracer_name_template is not None:
                resolved_name = tracer_name_template.format(module=module_path, class_name=qualified_name)
            elif module_path.startswith("transformers."):
                resolved_name = f"{module_path}.{qualified_name}"
            else:
                resolved_name = f"transformers.{module_path}.{qualified_name}"

            self.tracer = get_tracer(resolved_name)

        cls.__init__ = init_with_tracer
        return cls

    return decorator
def traced(
    func=None,
    *,
    span_name=None,
    standalone=False,
    additional_attributes: Optional[List[Tuple[str, str, Union[Any, Callable[[Any], Any]]]]] = None,
):
    """
    Decorator to trace function calls with OpenTelemetry.

    Can be used as `@traced` or `@traced(span_name="custom_name")`. When OpenTelemetry is unavailable the
    decorated function is returned unchanged.

    Each call opens a span that records the function name and module, whether it was recognised as a method,
    and one attribute per positional/keyword argument (the value for str/int/float/bool/None, the type
    otherwise). Exceptions raised by the function are recorded on the span, the span status is set to ERROR,
    and the exception is re-raised.

    Args:
        func: The function to trace.
        span_name: Optional custom name for the span (defaults to the function name).
        standalone: If True, the span is opened with `tracer.start_span` instead of
            `tracer.start_as_current_span`.
        additional_attributes: Optional list of additional attributes to set on the span for method calls.
            Each item is a tuple of (instance_attribute_name, span_attribute_key, value_or_transform_function):
            - instance_attribute_name: Name of the attribute to get from the class instance; the item is
              skipped when the instance does not have it.
            - span_attribute_key: Key to use when setting the attribute on the span.
            - value_or_transform_function: Either a raw value to use directly, or a callable applied to the
              instance attribute value to produce the span value.

    Returns:
        The decorated function, or a decorator when called with keyword arguments only.
    """

    def describe(value):
        """Return the span text for an argument: the value for primitives and None, else its type."""
        if value is None or isinstance(value, (str, int, float, bool)):
            return str(value)
        return str(type(value))

    def decorator(func):
        if not _has_opentelemetry:
            return func

        # A function written in a class body has a qualified name of the form "Owner.name". A plain function
        # object never carries `__self__`, so that attribute cannot be used to recognise such methods.
        owner_path = getattr(func, "__qualname__", func.__name__).rpartition(".")[0]
        defined_in_class = bool(owner_path) and not owner_path.endswith("<locals>")

        def resolve_instance(args):
            """Return the object the call is bound to, or None for a plain function call."""
            bound_self = getattr(func, "__self__", None)
            if bound_self is not None:
                # Already-bound callable: `self` is not part of `args`.
                return bound_self
            if args and defined_in_class and hasattr(type(args[0]), func.__name__):
                return args[0]
            return None

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            instance = resolve_instance(args)
            is_method = instance is not None

            # Prefer the tracer installed on the instance (see `attach_tracer`).
            if is_method and hasattr(instance, "tracer"):
                tracer = instance.tracer
            else:
                tracer = get_tracer(f"transformers.{func.__module__}.{func.__name__}")

            name = span_name or func.__name__
            span_fn = tracer.start_span if standalone else tracer.start_as_current_span
            with span_fn(name) as span:
                span.set_attribute("function.name", func.__name__)
                span.set_attribute("function.module", func.__module__)
                span.set_attribute("function.is_method", is_method)

                for position, arg in enumerate(args):
                    span.set_attribute(f"args.{position}", describe(arg))
                for key, value in kwargs.items():
                    span.set_attribute(f"kwargs.{key}", describe(value))

                if additional_attributes and is_method:
                    for instance_attribute_name, span_attribute_key, value_or_transform in additional_attributes:
                        if not hasattr(instance, instance_attribute_name):
                            continue
                        if callable(value_or_transform):
                            span_value = value_or_transform(getattr(instance, instance_attribute_name))
                        else:
                            span_value = value_or_transform
                        span.set_attribute(span_attribute_key, span_value)

                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    span.set_status(Status(StatusCode.ERROR))
                    span.record_exception(e)
                    raise

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@attach_tracer()
class ContinuousBatchProcessorMetrics:
    """Metrics collection for ContinuousBatchProcessor.

    All instruments are created in `_setup_metrics` and only exist when OpenTelemetry could be imported;
    every `record_*` method returns immediately otherwise. Recording is best-effort: failures while
    recording are logged as warnings and never propagated to the caller.
    """

    def __init__(self, max_batch_tokens: int):
        """Initialize metrics for continuous batch processor.

        Args:
            max_batch_tokens: Maximum number of tokens in a batch
        """
        self.max_batch_tokens = max_batch_tokens
        self._setup_metrics()

    def _setup_metrics(self):
        """Initialize OpenTelemetry metrics and tracing if the library is available."""
        if not _has_opentelemetry:
            logger.info("OpenTelemetry is not installed. Metrics and tracing will not be recorded.")
            return

        self.meter = metrics.get_meter("transformers.generation.continuous_batch_processor")

        # Define appropriate buckets for TTFT (typically ranges from ~50ms to several seconds)
        ttft_buckets = [10, 25, 50, 75, 100, 150, 200, 300, 500, 750, 1000, 2000, 5000, 10000]
        self.ttft_histogram = self.meter.create_histogram(
            name="ttft_milliseconds",
            description="Time to first token in milliseconds",
            unit="ms",
            explicit_bucket_boundaries_advisory=ttft_buckets,
        )

        self.active_requests_gauge = self.meter.create_gauge(
            name="active_requests_count",
            description="Number of active requests currently being processed",
            unit="requests",
        )
        self.waiting_requests_gauge = self.meter.create_gauge(
            name="waiting_requests_count",
            description="Number of requests waiting to be processed",
            unit="requests",
        )

        # Define appropriate buckets for request latency (similar to TTFT but with higher upper bounds)
        latency_buckets = [50, 100, 250, 500, 1000, 2000, 5000, 10000, 20000, 30000, 60000]
        self.request_latency_histogram = self.meter.create_histogram(
            name="request_latency_milliseconds",
            description="End-to-end latency for completed requests in milliseconds",
            unit="ms",
            explicit_bucket_boundaries_advisory=latency_buckets,
        )

        self.decode_prefill_ratio_gauge = self.meter.create_gauge(
            name="decode_prefill_ratio",
            description="Ratio of decode tokens to prefill tokens in a batch",
            unit="ratio",
        )
        self.prefill_tokens_counter = self.meter.create_counter(
            name="prefill_tokens_processed",
            description="Number of prefill tokens processed",
            unit="tokens",
        )
        self.decode_tokens_counter = self.meter.create_counter(
            name="decode_tokens_processed",
            description="Number of decode tokens processed",
            unit="tokens",
        )

        # Define appropriate buckets for batch fill percentage (0-100%)
        batch_fill_buckets = [5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 98, 100]
        self.batch_fill_percentage_histogram = self.meter.create_histogram(
            name="batch_fill_percentage",
            description="Percentage of max_batch_tokens utilized in each batch",
            unit="percent",
            explicit_bucket_boundaries_advisory=batch_fill_buckets,
        )

        self.kv_cache_free_memory_gauge = self.meter.create_gauge(
            name="kv_cache_free_memory_bytes",
            description="Free memory of the PagedAttentionCache in bytes",
            unit="bytes",
        )
        self.kv_cache_memory_gauge = self.meter.create_gauge(
            name="kv_cache_memory_bytes",
            description="Memory usage of the PagedAttentionCache in bytes",
            unit="bytes",
        )

    @traced
    def record_ttft_metric(self, created_time: float, request_id: str) -> None:
        """Record Time to First Token (TTFT).

        Args:
            created_time: The time the request was created, on the `time.time()` clock
            request_id: The ID of the request (used for logging only)
        """
        if not _has_opentelemetry:
            return

        ttft_ms = (time.time() - created_time) * 1000.0
        try:
            self.ttft_histogram.record(ttft_ms)
            logger.debug(f"Recorded TTFT for request {request_id}: {ttft_ms:.2f}ms")
        except Exception as e:
            logger.warning(f"Failed to record TTFT metric: {e}")

    @traced
    def record_batch_metrics(self, requests_in_batch: List) -> None:
        """Record metrics about the batch composition including decode/prefill ratio and batch fill percentage.

        A request in `DECODING` status contributes one decode token; a request in `PREFILLING` or
        `PREFILLING_SPLIT` status contributes `len(state.prompt_ids)` prefill tokens. Other statuses are
        ignored.

        Args:
            requests_in_batch: List of request states in the current batch
        """
        if not _has_opentelemetry or not requests_in_batch:
            return

        decode_tokens = 0
        prefill_tokens = 0
        for state in requests_in_batch:
            if state.status == RequestStatus.DECODING:
                decode_tokens += 1
            elif state.status in (RequestStatus.PREFILLING, RequestStatus.PREFILLING_SPLIT):
                prefill_tokens += len(state.prompt_ids)
        total_batch_tokens = decode_tokens + prefill_tokens

        try:
            if prefill_tokens > 0:
                self.prefill_tokens_counter.add(prefill_tokens)
            if decode_tokens > 0:
                self.decode_tokens_counter.add(decode_tokens)
            # The ratio is undefined for decode-only batches, so the gauge is left untouched there.
            if prefill_tokens > 0:
                ratio = decode_tokens / prefill_tokens
                self.decode_prefill_ratio_gauge.set(ratio)

            fill_percentage = (total_batch_tokens / self.max_batch_tokens) * 100.0
            self.batch_fill_percentage_histogram.record(fill_percentage)
            logger.debug(
                f"Batch metrics: {decode_tokens} decode tokens, {prefill_tokens} prefill tokens, "
                f"batch fill: {fill_percentage:.2f}% ({total_batch_tokens}/{self.max_batch_tokens})"
            )
        except Exception as e:
            logger.warning(f"Failed to record batch metrics: {e}")

    @traced
    def record_kv_cache_memory_metrics(self, cache) -> None:
        """Record memory usage of the PagedAttentionCache without GPU synchronization.

        This calculates the theoretical memory usage based on cache configuration
        and the number of blocks currently in use.

        Args:
            cache: The PagedAttentionCache object to measure. `num_blocks`, `_free_blocks`, `key_cache`,
                `dtype`, `block_size`, `num_key_value_heads` and `head_dim` are read from it.
        """
        if not _has_opentelemetry:
            return

        try:
            num_free_blocks = len(cache._free_blocks)
            num_used_blocks = cache.num_blocks - num_free_blocks
            num_layers = len(cache.key_cache)

            # Use the real element size of the cache dtype so that 1-byte and 8-byte dtypes are not
            # misreported. Fall back to the 2-or-4-byte heuristic when the dtype object does not expose
            # `itemsize`.
            bytes_per_parameter = getattr(cache.dtype, "itemsize", None)
            if bytes_per_parameter is None:
                bytes_per_parameter = 2 if cache.dtype in (torch.float16, torch.bfloat16) else 4

            # Bytes held by one block across all layers:
            # num_layers * block_size * num_kv_heads * head_dim * 2 (both K and V) * bytes_per_parameter
            bytes_per_block = (
                num_layers
                * cache.block_size
                * cache.num_key_value_heads
                * cache.head_dim
                * 2  # For both key and value caches
                * bytes_per_parameter
            )
            memory_bytes = num_used_blocks * bytes_per_block
            free_memory_bytes = num_free_blocks * bytes_per_block

            self.kv_cache_memory_gauge.set(memory_bytes)
            self.kv_cache_free_memory_gauge.set(free_memory_bytes)
            logger.debug(
                f"KV Cache memory: {memory_bytes / (1024 * 1024):.2f}MB, "
                f"Used blocks: {num_used_blocks}/{cache.num_blocks} "
                f"({num_used_blocks / cache.num_blocks * 100:.1f}%)"
            )
        except Exception as e:
            logger.warning(f"Failed to record KV cache memory metrics: {e}")

    @traced
    def record_queue_metrics(self, active_requests: int, waiting_requests: int) -> None:
        """Record metrics about active and waiting requests.

        Args:
            active_requests: Number of active requests
            waiting_requests: Number of waiting requests
        """
        if not _has_opentelemetry:
            return

        try:
            self.active_requests_gauge.set(active_requests)
            self.waiting_requests_gauge.set(waiting_requests)
            logger.debug(f"Queue metrics: {active_requests} active requests, {waiting_requests} waiting requests")
        except Exception as e:
            logger.warning(f"Failed to record queue metrics: {e}")

    @traced
    def record_request_completion(self, created_time: float, request_id: str) -> None:
        """Record metrics about a completed request.

        Args:
            created_time: The time the request was created, on the `time.time()` clock
            request_id: The ID of the request (used for logging only)
        """
        if not _has_opentelemetry:
            return

        latency_ms = (time.time() - created_time) * 1000.0
        try:
            self.request_latency_histogram.record(latency_ms)
            logger.debug(f"Recorded request completion for {request_id}: {latency_ms:.2f}ms")
        except Exception as e:
            logger.warning(f"Failed to record request completion metric: {e}")

View File

@@ -0,0 +1,86 @@
import time
import unittest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_flash_attn, require_torch_gpu, run_slow
# Prompts fed to `generate_batch`. The strings are test data: `_EXPECTED_OUTPUTS` was presumably recorded
# for exactly these strings, so do not edit them (wording included) without regenerating the expectations.
_TEST_PROMPTS = [
    "A man is a walking his dog down the street, and a the turn he sees",
    "Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
    "A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
    "Please fill in the form to",
    "For safety reasons, the train is stopped in the middle of the",
]

# Expected start of each generation; entry i is compared with the i-th generated output in the test below.
_EXPECTED_OUTPUTS = [
    "a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes up a conversation, and they quickly discover that they have a lot in common. They exchange numbers and",
    "orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n## Step 2: Determine the taste and nutritional value of the fruit\nThe fruit is described as sweet",
    "This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer is not a straightforward one, and it requires some lateral thinking to arrive at the correct solution.",
    "get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]\n[Your Message]\n\nWe are looking forward to hearing from you!\n\n[Insert Contact Information]\n\nNote:",
    "track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min",
]
@run_slow
@require_torch_gpu
@require_flash_attn
class TestBatchGeneration(unittest.TestCase):
    """Checks that `generate_batch` output starts with the recorded text for each attention implementation."""

    @classmethod
    def setUpClass(cls):
        """Load the model and tokenizer once for all parameterized cases."""
        loaded_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3b-Instruct", torch_dtype="bfloat16", device_map="auto"
        )
        cls.model = loaded_model.eval()
        cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")

        # Reuse EOS as the padding token when the tokenizer does not define one.
        if cls.tokenizer.pad_token is None:
            cls.tokenizer.pad_token = cls.tokenizer.eos_token
            cls.model.config.pad_token_id = cls.model.config.eos_token_id

        cls.model.use_cache = False

    @parameterized.expand(
        [
            ("eager_paged", 64, 128, 64),
            ("sdpa_paged", 32, 256, 128),
            ("paged_attention", 16, 512, 256),
            ("flex_paged", 64, 128, 64),
        ]
    )
    def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        """Generate for all prompts in one batch and compare each result with its expected prefix."""
        self.model.config.attn_implementation = attn_impl

        config_kwargs = {
            "max_new_tokens": 50,
            "top_k": 0,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "use_cache": False,
            "num_blocks": num_blocks,
            "block_size": block_size,
            "max_batch_tokens": max_batch_tokens,
        }
        generation_config = GenerationConfig(**config_kwargs)

        encoded = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
        prompt_token_ids = list(encoded["input_ids"])

        started_at = time.time()
        batch_outputs = self.model.generate_batch(
            inputs=prompt_token_ids,
            generation_config=generation_config,
        )
        elapsed = time.time() - started_at
        print(
            f"\n[{attn_impl}] Batch took {elapsed:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
        )

        # Outputs are matched to the expectations by iteration order of `batch_outputs`.
        for i, req_id in enumerate(batch_outputs):
            output_ids = batch_outputs[req_id].static_outputs
            generated = self.tokenizer.decode(output_ids, skip_special_tokens=False).strip()
            expected = _EXPECTED_OUTPUTS[i].strip()
            matches_prefix = generated.startswith(expected)
            self.assertTrue(
                matches_prefix,
                msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
            )