commit for merge
This commit is contained in:
222
sc/hf_api.py
Normal file
222
sc/hf_api.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""
|
||||||
|
HTTP API for HuggingFace causal LM inference (answer + logits).
|
||||||
|
|
||||||
|
This module exposes a small FastAPI service that:
|
||||||
|
- Loads a local HuggingFace CausalLM once at startup
|
||||||
|
- Accepts chat-style messages via HTTP
|
||||||
|
- Returns generated text
|
||||||
|
- Returns *logits-derived* information in a practical size:
|
||||||
|
- top-k logprobs for each generated token (recommended)
|
||||||
|
- optional prompt logits top-k for the final prompt position
|
||||||
|
|
||||||
|
Why not return full logits?
|
||||||
|
Full logits are extremely large (seq_len x vocab_size) and will quickly
|
||||||
|
overwhelm network/memory. This API defaults to returning top-k logprobs.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
MODEL_DIR=/path/to/model \\
|
||||||
|
uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
|
||||||
|
|
||||||
|
Example:
|
||||||
|
curl -X POST http://localhost:8000/generate \\
|
||||||
|
-H 'Content-Type: application/json' \\
|
||||||
|
-d '{"messages":[{"role":"user","content":"Hello!"}],"max_new_tokens":64,"top_k":10}'
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MODEL_DIR = os.environ.get(
|
||||||
|
"MODEL_DIR", "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
|
||||||
|
)
|
||||||
|
|
||||||
|
app = FastAPI(title="HF LLM API", version="0.1.0")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: Literal["system", "user", "assistant"]
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(BaseModel):
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
max_new_tokens: int = Field(default=128, ge=1, le=1024)
|
||||||
|
temperature: float = Field(default=0.0, ge=0.0, le=2.0)
|
||||||
|
top_p: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
do_sample: bool = False
|
||||||
|
top_k: int = Field(default=20, ge=1, le=200)
|
||||||
|
|
||||||
|
# If True, also returns prompt last-position top-k logits (not full matrix)
|
||||||
|
include_prompt_topk: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class TokenTopK(BaseModel):
|
||||||
|
token_id: int
|
||||||
|
token: str
|
||||||
|
logprob: float
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratedStep(BaseModel):
|
||||||
|
token_id: int
|
||||||
|
token: str
|
||||||
|
logprob: float
|
||||||
|
topk: List[TokenTopK]
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTopK(BaseModel):
|
||||||
|
position: int
|
||||||
|
topk: List[TokenTopK]
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResponse(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
generated_text: str
|
||||||
|
generated_token_ids: List[int]
|
||||||
|
steps: List[GeneratedStep]
|
||||||
|
prompt_topk: Optional[PromptTopK] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_device() -> str:
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_model_and_tokenizer(model_dir: str):
|
||||||
|
device = _get_device()
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_dir,
|
||||||
|
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
|
||||||
|
device_map="auto" if device == "cuda" else None,
|
||||||
|
).eval()
|
||||||
|
return tokenizer, model, device
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def _startup_load():
|
||||||
|
# Load once; shared across requests.
|
||||||
|
global tokenizer, model, device
|
||||||
|
tokenizer, model, device = _load_model_and_tokenizer(DEFAULT_MODEL_DIR)
|
||||||
|
app.state.model_dir = DEFAULT_MODEL_DIR
|
||||||
|
app.state.device = device
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health() -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"model_dir": getattr(app.state, "model_dir", None),
|
||||||
|
"device": getattr(app.state, "device", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_chat_template(messages: List[ChatMessage]) -> str:
|
||||||
|
# Convert pydantic objects to plain dicts compatible with HF template.
|
||||||
|
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
|
||||||
|
try:
|
||||||
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
msg_dicts,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Fallback: naive concatenation
|
||||||
|
prompt = "\n".join([f"{m['role']}: {m['content']}" for m in msg_dicts]) + "\nassistant:"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _topk_from_logits(logits_1d: torch.Tensor, top_k: int) -> List[TokenTopK]:
|
||||||
|
# logits_1d: (vocab,)
|
||||||
|
top_vals, top_ids = torch.topk(logits_1d, k=top_k)
|
||||||
|
# Convert to logprobs for interpretability
|
||||||
|
logprobs = torch.log_softmax(logits_1d, dim=-1)
|
||||||
|
out: List[TokenTopK] = []
|
||||||
|
for tid in top_ids.tolist():
|
||||||
|
tok = tokenizer.decode([tid])
|
||||||
|
out.append(
|
||||||
|
TokenTopK(
|
||||||
|
token_id=int(tid),
|
||||||
|
token=tok,
|
||||||
|
logprob=float(logprobs[tid].detach().cpu().item()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Sort in descending logprob (topk preserves order, but be explicit)
|
||||||
|
out.sort(key=lambda x: x.logprob, reverse=True)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate", response_model=GenerateResponse)
|
||||||
|
def generate(req: GenerateRequest) -> GenerateResponse:
|
||||||
|
if not hasattr(app.state, "model_dir"):
|
||||||
|
raise HTTPException(status_code=503, detail="Model not loaded yet")
|
||||||
|
|
||||||
|
prompt = _apply_chat_template(req.messages)
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
if device == "cuda":
|
||||||
|
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# Use generate() so we can get per-step scores (logits).
|
||||||
|
# output.scores is a list with length = generated_tokens
|
||||||
|
# each element shape: (batch, vocab)
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=req.max_new_tokens,
|
||||||
|
do_sample=req.do_sample if req.temperature > 0 else False,
|
||||||
|
temperature=req.temperature if req.temperature > 0 else 1.0,
|
||||||
|
top_p=req.top_p,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generated token ids include prompt + new tokens
|
||||||
|
seq = out.sequences[0]
|
||||||
|
prompt_len = int(inputs["input_ids"].shape[1])
|
||||||
|
gen_token_ids = seq[prompt_len:].tolist()
|
||||||
|
generated_text = tokenizer.decode(gen_token_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Build per-step top-k + chosen token logprob
|
||||||
|
steps: List[GeneratedStep] = []
|
||||||
|
if out.scores is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Model did not return scores")
|
||||||
|
|
||||||
|
for step_idx, step_logits in enumerate(out.scores):
|
||||||
|
# step_logits: (1, vocab)
|
||||||
|
step_logits_1d = step_logits[0]
|
||||||
|
chosen_id = int(gen_token_ids[step_idx]) if step_idx < len(gen_token_ids) else None
|
||||||
|
logprobs_1d = torch.log_softmax(step_logits_1d, dim=-1)
|
||||||
|
chosen_logprob = float(logprobs_1d[chosen_id].detach().cpu().item()) if chosen_id is not None else float("nan")
|
||||||
|
steps.append(
|
||||||
|
GeneratedStep(
|
||||||
|
token_id=chosen_id,
|
||||||
|
token=tokenizer.decode([chosen_id]),
|
||||||
|
logprob=chosen_logprob,
|
||||||
|
topk=_topk_from_logits(step_logits_1d, req.top_k),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_topk: Optional[PromptTopK] = None
|
||||||
|
if req.include_prompt_topk:
|
||||||
|
with torch.no_grad():
|
||||||
|
forward = model(**inputs)
|
||||||
|
# forward.logits: (1, seq_len, vocab)
|
||||||
|
last_pos = int(forward.logits.shape[1] - 1)
|
||||||
|
last_logits = forward.logits[0, -1, :]
|
||||||
|
prompt_topk = PromptTopK(position=last_pos, topk=_topk_from_logits(last_logits, req.top_k))
|
||||||
|
|
||||||
|
return GenerateResponse(
|
||||||
|
prompt=prompt,
|
||||||
|
generated_text=generated_text,
|
||||||
|
generated_token_ids=[int(t) for t in gen_token_ids],
|
||||||
|
steps=steps,
|
||||||
|
prompt_topk=prompt_topk,
|
||||||
|
)
|
||||||
|
|
||||||
@@ -1,23 +1,31 @@
|
|||||||
import os
|
import torch
|
||||||
import sys
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
|
||||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
model_dir = "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
|
||||||
if _PROJECT_ROOT not in sys.path:
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
import asyncio
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
|
||||||
from fire import Fire
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
from sc.core.model import load_models
|
model_dir,
|
||||||
from langchain_core.messages import HumanMessage
|
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
|
||||||
|
device_map="auto" if device == "cuda" else None,
|
||||||
|
).eval()
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
|
||||||
|
]
|
||||||
|
|
||||||
async def main():
|
# Convert chat messages -> a single prompt string using the model's chat template
|
||||||
models = load_models(["ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b"])
|
prompt = tokenizer.apply_chat_template(
|
||||||
prompt = "nice to meet you."
|
messages,
|
||||||
response = await models.invoke([HumanMessage(content=prompt)], logprobs=True, top_logprobs=20)
|
tokenize=False,
|
||||||
print("Response:")
|
add_generation_prompt=True,
|
||||||
print(response)
|
)
|
||||||
|
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
with torch.no_grad():
|
||||||
Fire(main)
|
out = model(**inputs)
|
||||||
|
|
||||||
|
logits = out.logits # shape: (batch=1, seq_len, vocab_size)
|
||||||
|
print("logits shape:", logits.shape)
|
||||||
124
sc/transformers_serve_client_example.py
Normal file
124
sc/transformers_serve_client_example.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
Call transformers serve (chat) and our HF API (logprobs) from Python with requests.
|
||||||
|
|
||||||
|
Summary:
|
||||||
|
--------
|
||||||
|
- transformers serve (transformers chat localhost:8000 --model-name-or-path ...):
|
||||||
|
- Exposes OpenAI-compatible endpoints: /v1/chat/completions, /v1/responses, /v1/models.
|
||||||
|
- It does NOT return logits or logprobs; the response chunks have "logprobs": null.
|
||||||
|
- You can still use it for chat from Python via requests (see chat_with_transformers_serve).
|
||||||
|
|
||||||
|
- For answer + logprobs (top-k per token):
|
||||||
|
- Use our custom API in sc/hf_api.py:
|
||||||
|
MODEL_DIR=/path/to/model uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
|
||||||
|
- Then call POST /generate with requests (see get_logprobs_via_hf_api).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
------
|
||||||
|
# Chat only (transformers serve on 8000):
|
||||||
|
python -c "
|
||||||
|
from sc.transformers_serve_client_example import chat_with_transformers_serve
|
||||||
|
print(chat_with_transformers_serve('Hello!'))
|
||||||
|
"
|
||||||
|
|
||||||
|
# Answer + logprobs (sc/hf_api on 8000):
|
||||||
|
python -c "
|
||||||
|
from sc.transformers_serve_client_example import get_logprobs_via_hf_api
|
||||||
|
r = get_logprobs_via_hf_api([{'role':'user','content':'Hello!'}])
|
||||||
|
print('text:', r['generated_text'])
|
||||||
|
print('first step top-k:', r['steps'][0]['topk'])
|
||||||
|
"
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
# Default base URL for transformers serve (OpenAI-compatible)
|
||||||
|
TRANSFORMERS_SERVE_URL = "http://localhost:8000/v1"
|
||||||
|
# Default base URL for our sc/hf_api (generate + logprobs)
|
||||||
|
HF_API_URL = "http://localhost:8000"
|
||||||
|
|
||||||
|
|
||||||
|
def chat_with_transformers_serve(
|
||||||
|
user_message: str,
|
||||||
|
*,
|
||||||
|
base_url: str = TRANSFORMERS_SERVE_URL,
|
||||||
|
model: str = "openai/gpt-oss-20b",
|
||||||
|
max_tokens: int = 256,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Send a chat message to a server running `transformers serve`.
|
||||||
|
Returns the assistant reply text. No logits/logprobs (server does not provide them).
|
||||||
|
"""
|
||||||
|
url = f"{base_url.rstrip('/')}/chat/completions"
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": [{"role": "user", "content": user_message}],
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
resp = requests.post(url, json=payload, timeout=60)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
# Non-stream: choices[0].message.content
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
return ""
|
||||||
|
msg = choices[0].get("message", {})
|
||||||
|
return msg.get("content", "") or ""
|
||||||
|
|
||||||
|
|
||||||
|
def get_logprobs_via_hf_api(
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
*,
|
||||||
|
base_url: str = HF_API_URL,
|
||||||
|
max_new_tokens: int = 64,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Call our sc/hf_api POST /generate endpoint.
|
||||||
|
Returns generated text and per-token top-k logprobs (no raw logits over the wire).
|
||||||
|
"""
|
||||||
|
url = f"{base_url.rstrip('/')}/generate"
|
||||||
|
payload = {
|
||||||
|
"messages": messages,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"top_k": top_k,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"do_sample": False,
|
||||||
|
}
|
||||||
|
resp = requests.post(url, json=payload, timeout=120)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: python -m sc.transformers_serve_client_example chat|logprobs [message]")
|
||||||
|
print(" chat -> call transformers serve /v1/chat/completions (no logprobs)")
|
||||||
|
print(" logprobs -> call sc/hf_api /generate (returns top-k logprobs)")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
cmd = sys.argv[1].lower()
|
||||||
|
message = (sys.argv[2] if len(sys.argv) > 2 else "Hello, how are you?").strip()
|
||||||
|
|
||||||
|
if cmd == "chat":
|
||||||
|
text = chat_with_transformers_serve(message)
|
||||||
|
print("Reply:", text)
|
||||||
|
elif cmd == "logprobs":
|
||||||
|
out = get_logprobs_via_hf_api([{"role": "user", "content": message}])
|
||||||
|
print("Generated:", out.get("generated_text", ""))
|
||||||
|
print("Steps (first 3):")
|
||||||
|
for s in out.get("steps", [])[:3]:
|
||||||
|
print(" token:", repr(s.get("token")), "logprob:", s.get("logprob"), "topk:", [t.get("token") for t in s.get("topk", [])[:5]])
|
||||||
|
else:
|
||||||
|
print("Unknown command. Use 'chat' or 'logprobs'.")
|
||||||
|
sys.exit(1)
|
||||||
Reference in New Issue
Block a user