commit for merge
This commit is contained in:
222
sc/hf_api.py
Normal file
222
sc/hf_api.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
HTTP API for HuggingFace causal LM inference (answer + logits).
|
||||
|
||||
This module exposes a small FastAPI service that:
|
||||
- Loads a local HuggingFace CausalLM once at startup
|
||||
- Accepts chat-style messages via HTTP
|
||||
- Returns generated text
|
||||
- Returns *logits-derived* information in a practical size:
|
||||
- top-k logprobs for each generated token (recommended)
|
||||
- optional prompt logits top-k for the final prompt position
|
||||
|
||||
Why not return full logits?
|
||||
Full logits are extremely large (seq_len x vocab_size) and will quickly
|
||||
overwhelm network/memory. This API defaults to returning top-k logprobs.
|
||||
|
||||
Run:
|
||||
MODEL_DIR=/path/to/model \\
|
||||
uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Example:
|
||||
curl -X POST http://localhost:8000/generate \\
|
||||
-H 'Content-Type: application/json' \\
|
||||
-d '{"messages":[{"role":"user","content":"Hello!"}],"max_new_tokens":64,"top_k":10}'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# Directory of the local HuggingFace checkpoint to serve. The MODEL_DIR
# environment variable wins; the literal path is only the fallback.
DEFAULT_MODEL_DIR = os.getenv(
    "MODEL_DIR",
    "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b",
)
|
||||
|
||||
# The ASGI application object; the route and startup decorators in this
# module all register against it.
app = FastAPI(
    title="HF LLM API",
    version="0.1.0",
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    # One chat turn as sent by the client; converted to a plain
    # {"role", "content"} dict before the chat template is applied.

    # Speaker of the turn; only these three values pass validation.
    role: Literal["system", "user", "assistant"]
    # Text of the turn.
    content: str
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
    # Request body for POST /generate.

    # Chat turns that are rendered into the prompt.
    messages: List[ChatMessage]
    # Upper bound on newly generated tokens (validated to 1..1024).
    max_new_tokens: int = Field(default=128, ge=1, le=1024)
    # Sampling temperature. The handler only samples when this is > 0, so the
    # default of 0.0 means greedy decoding.
    temperature: float = Field(default=0.0, ge=0.0, le=2.0)
    # Nucleus-sampling mass forwarded to `model.generate`.
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    # Request sampling; has no effect unless `temperature` is also > 0.
    do_sample: bool = False
    # Number of highest-logprob candidates reported per position. This sizes
    # the returned `topk` lists; it is not forwarded to `model.generate`.
    top_k: int = Field(default=20, ge=1, le=200)

    # If True, also returns prompt last-position top-k logits (not full matrix)
    include_prompt_topk: bool = False
|
||||
|
||||
|
||||
class TokenTopK(BaseModel):
    # One candidate token in a top-k list.

    # Vocabulary id of the candidate.
    token_id: int
    # The candidate decoded on its own via `tokenizer.decode([token_id])`.
    token: str
    # Log-softmax value of the candidate at that position.
    logprob: float
|
||||
|
||||
|
||||
class GeneratedStep(BaseModel):
    # Everything reported for one generated token.

    # Id of the token that was actually emitted at this step.
    token_id: int
    # The emitted token decoded on its own.
    token: str
    # Log-softmax of this step's scores evaluated at the emitted token.
    logprob: float
    # Highest-logprob candidates at this step, best first.
    topk: List[TokenTopK]
|
||||
|
||||
|
||||
class PromptTopK(BaseModel):
    # Top-k candidates predicted from the prompt alone.

    # Index of the last prompt position (sequence length minus one).
    position: int
    # Highest-logprob candidates at that position, best first.
    topk: List[TokenTopK]
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
    # Response body of POST /generate.

    # The prompt string that was tokenized and fed to the model.
    prompt: str
    # New tokens decoded with special tokens skipped.
    generated_text: str
    # Ids of the new tokens only; the prompt ids are not included.
    generated_token_ids: List[int]
    # One entry per generation step.
    steps: List[GeneratedStep]
    # Filled only when the request set `include_prompt_topk`.
    prompt_topk: Optional[PromptTopK] = None
|
||||
|
||||
|
||||
def _get_device() -> str:
    """Return ``"cuda"`` when torch reports CUDA as available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
||||
|
||||
|
||||
def _load_model_and_tokenizer(model_dir: str):
    """Load the tokenizer and causal LM stored under *model_dir*.

    On CUDA the weights are requested in bfloat16 with ``device_map="auto"``;
    on CPU they are requested in float32 with no device map. The model is
    switched to eval mode before being returned.

    Returns:
        A ``(tokenizer, model, device)`` tuple where ``device`` is the string
        produced by ``_get_device()``.
    """
    device = _get_device()
    on_gpu = device == "cuda"

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16 if on_gpu else torch.float32,
        device_map="auto" if on_gpu else None,
    )
    model = model.eval()
    return tokenizer, model, device
|
||||
|
||||
|
||||
@app.on_event("startup")
def _startup_load():
    """Load the model once at application startup and publish it.

    The tokenizer, model and device string are stored in module globals that
    the request handlers read. The model directory and device are mirrored
    onto ``app.state`` so that ``/health`` can report them and ``/generate``
    can tell whether loading has happened.
    """
    global tokenizer, model, device
    loaded = _load_model_and_tokenizer(DEFAULT_MODEL_DIR)
    tokenizer, model, device = loaded
    app.state.model_dir = DEFAULT_MODEL_DIR
    app.state.device = device
|
||||
|
||||
|
||||
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report liveness plus the model directory and device, if already set.

    ``model_dir`` and ``device`` are ``None`` when the startup hook has not
    populated ``app.state`` yet; ``ok`` is always ``True``.
    """
    state = app.state
    report: Dict[str, Any] = {"ok": True}
    report["model_dir"] = getattr(state, "model_dir", None)
    report["device"] = getattr(state, "device", None)
    return report
|
||||
|
||||
|
||||
def _apply_chat_template(messages: List[ChatMessage]) -> str:
    """Render chat messages into the single prompt string fed to the model.

    The tokenizer's chat template is used with ``add_generation_prompt=True``
    and ``tokenize=False``. If applying the template raises, the function
    falls back to a plain ``"role: content"`` transcript ending in
    ``"assistant:"``. The fallback is kept as a best effort, but it is now
    logged: a prompt in the wrong format changes the model's output and
    logprobs, so the switch must not happen silently.

    Args:
        messages: Validated chat turns from the request.

    Returns:
        The prompt text (not tokenized).
    """
    # Local import: this module has no file-level logging import.
    import logging

    # Convert pydantic objects to plain dicts compatible with the HF template.
    msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
    try:
        return tokenizer.apply_chat_template(
            msg_dicts,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Deliberately broad: any template failure should degrade to the
        # naive format rather than fail the request, but leave a trace.
        logging.getLogger(__name__).warning(
            "Chat template could not be applied; using naive 'role: content' prompt",
            exc_info=True,
        )
        transcript = [f"{m['role']}: {m['content']}" for m in msg_dicts]
        return "\n".join(transcript) + "\nassistant:"
|
||||
|
||||
|
||||
def _topk_from_logits(logits_1d: torch.Tensor, top_k: int) -> List[TokenTopK]:
    """Return the *top_k* most likely tokens for one position, best first.

    Args:
        logits_1d: Scores over the vocabulary for a single position,
            indexed as a 1-D tensor.
        top_k: Number of candidates wanted. Values larger than the vocabulary
            are clamped, because ``torch.topk`` raises when ``k`` exceeds the
            dimension size.

    Returns:
        ``TokenTopK`` entries sorted by descending logprob.
    """
    # Normalise in float32: a log-softmax computed in a half-precision model
    # dtype keeps only a few significant digits.
    logprobs = torch.log_softmax(logits_1d.detach().float(), dim=-1)

    k = max(0, min(int(top_k), logprobs.shape[-1]))
    # log-softmax is monotonic, so ranking logprobs equals ranking logits;
    # torch.topk already returns values sorted in descending order.
    top_logprobs, top_ids = torch.topk(logprobs, k=k)

    # One bulk transfer instead of one .cpu().item() round-trip per candidate.
    ids = top_ids.tolist()
    values = top_logprobs.tolist()
    return [
        TokenTopK(
            token_id=int(tid),
            token=tokenizer.decode([tid]),
            logprob=float(lp),
        )
        for tid, lp in zip(ids, values)
    ]
|
||||
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate a reply for the chat in *req* and report per-token logprobs.

    The messages are rendered to a prompt, tokenized, and passed to
    ``model.generate`` with ``output_scores=True``. For every generated token
    the response carries the token, its logprob under that step's scores, and
    the ``req.top_k`` best candidates. When ``req.include_prompt_topk`` is
    set, one extra forward pass over the prompt supplies the top-k candidates
    at the last prompt position.

    Sampling is only enabled when ``req.do_sample`` is true *and*
    ``req.temperature > 0``; otherwise decoding is greedy.

    Raises:
        HTTPException: 503 if the startup hook has not loaded the model yet,
            500 if ``model.generate`` returned no scores.
    """
    if not hasattr(app.state, "model_dir"):
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    prompt = _apply_chat_template(req.messages)
    # NOTE(review): the prompt is already rendered by the chat template, and
    # the tokenizer call below uses default special-token handling. Confirm
    # this does not add a second BOS for the served model.
    inputs = tokenizer(prompt, return_tensors="pt")
    if device == "cuda":
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    sampling = req.temperature > 0
    # generate() exposes per-step scores: out.scores has one (batch, vocab)
    # tensor per generated token.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample if sampling else False,
            # Temperature 0 is not a usable divisor; greedy runs pass the
            # neutral value 1.0 instead.
            temperature=req.temperature if sampling else 1.0,
            top_p=req.top_p,
            return_dict_in_generate=True,
            output_scores=True,
        )

    if out.scores is None:
        raise HTTPException(status_code=500, detail="Model did not return scores")

    # out.sequences holds prompt + new tokens; slice the prompt off.
    prompt_len = int(inputs["input_ids"].shape[1])
    gen_token_ids = out.sequences[0][prompt_len:].tolist()
    generated_text = tokenizer.decode(gen_token_ids, skip_special_tokens=True)

    # Pair each emitted token with the scores it was chosen from. zip stops at
    # the shorter side, so a length mismatch can never produce a step without
    # a token id (which would fail both decoding and response validation).
    steps: List[GeneratedStep] = []
    for chosen_id, step_scores in zip(gen_token_ids, out.scores):
        # step_scores is (1, vocab); upcast so logprobs are computed in float32.
        scores_1d = step_scores[0].float()
        logprobs_1d = torch.log_softmax(scores_1d, dim=-1)
        steps.append(
            GeneratedStep(
                token_id=int(chosen_id),
                token=tokenizer.decode([chosen_id]),
                logprob=float(logprobs_1d[chosen_id].item()),
                topk=_topk_from_logits(scores_1d, req.top_k),
            )
        )

    prompt_topk: Optional[PromptTopK] = None
    if req.include_prompt_topk:
        with torch.no_grad():
            forward = model(**inputs)
        # forward.logits: (1, seq_len, vocab); only the last position is used.
        last_pos = int(forward.logits.shape[1] - 1)
        last_logits = forward.logits[0, -1, :].float()
        prompt_topk = PromptTopK(
            position=last_pos,
            topk=_topk_from_logits(last_logits, req.top_k),
        )

    return GenerateResponse(
        prompt=prompt,
        generated_text=generated_text,
        generated_token_ids=[int(t) for t in gen_token_ids],
        steps=steps,
        prompt_topk=prompt_topk,
    )
|
||||
|
||||
@@ -1,23 +1,31 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
model_dir = "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
import asyncio
|
||||
from fire import Fire
|
||||
from sc.core.model import load_models
|
||||
from langchain_core.messages import HumanMessage
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir,
|
||||
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
|
||||
device_map="auto" if device == "cuda" else None,
|
||||
).eval()
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
|
||||
]
|
||||
|
||||
async def main():
|
||||
models = load_models(["ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b"])
|
||||
prompt = "nice to meet you."
|
||||
response = await models.invoke([HumanMessage(content=prompt)], logprobs=True, top_logprobs=20)
|
||||
print("Response:")
|
||||
print(response)
|
||||
# Convert chat messages -> a single prompt string using the model's chat template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
with torch.no_grad():
|
||||
out = model(**inputs)
|
||||
|
||||
logits = out.logits # shape: (batch=1, seq_len, vocab_size)
|
||||
print("logits shape:", logits.shape)
|
||||
124
sc/transformers_serve_client_example.py
Normal file
124
sc/transformers_serve_client_example.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Call transformers serve (chat) and our HF API (logprobs) from Python with requests.
|
||||
|
||||
Summary:
|
||||
--------
|
||||
- transformers serve (transformers chat localhost:8000 --model-name-or-path ...):
|
||||
- Exposes OpenAI-compatible endpoints: /v1/chat/completions, /v1/responses, /v1/models.
|
||||
- It does NOT return logits or logprobs; the response chunks have "logprobs": null.
|
||||
- You can still use it for chat from Python via requests (see chat_with_transformers_serve).
|
||||
|
||||
- For answer + logprobs (top-k per token):
|
||||
- Use our custom API in sc/hf_api.py:
|
||||
MODEL_DIR=/path/to/model uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
|
||||
- Then call POST /generate with requests (see get_logprobs_via_hf_api).
|
||||
|
||||
Usage:
|
||||
------
|
||||
# Chat only (transformers serve on 8000):
|
||||
python -c "
|
||||
from sc.transformers_serve_client_example import chat_with_transformers_serve
|
||||
print(chat_with_transformers_serve('Hello!'))
|
||||
"
|
||||
|
||||
# Answer + logprobs (sc/hf_api on 8000):
|
||||
python -c "
|
||||
from sc.transformers_serve_client_example import get_logprobs_via_hf_api
|
||||
r = get_logprobs_via_hf_api([{'role':'user','content':'Hello!'}])
|
||||
print('text:', r['generated_text'])
|
||||
print('first step top-k:', r['steps'][0]['topk'])
|
||||
"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
# Base URL of an OpenAI-compatible `transformers serve` instance; note the
# /v1 prefix that its endpoints live under.
TRANSFORMERS_SERVE_URL = "http://localhost:8000/v1"
# Base URL of the sc/hf_api service (POST /generate with logprobs).
# Both defaults point at port 8000; pass base_url=... to the helpers below
# when the two servers listen on different ports.
HF_API_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
def _join_stream_deltas(lines: Any) -> str:
    """Concatenate the text deltas of an OpenAI-style SSE chat stream.

    *lines* is an iterable of raw lines (bytes or str). Only ``data:`` lines
    are considered; the ``[DONE]`` sentinel ends the stream. Chunks that are
    not valid JSON or carry no ``choices[0].delta.content`` are skipped.
    """
    pieces: List[str] = []
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        line = line.strip()
        if not line.startswith("data:"):
            continue  # blank keep-alives, comments, other SSE fields
        body = line[len("data:"):].strip()
        if body == "[DONE]":
            break
        try:
            chunk = json.loads(body)
        except json.JSONDecodeError:
            continue
        choices = chunk.get("choices") or []
        if not choices:
            continue
        text = (choices[0].get("delta") or {}).get("content")
        if text:
            pieces.append(text)
    return "".join(pieces)


def chat_with_transformers_serve(
    user_message: str,
    *,
    base_url: str = TRANSFORMERS_SERVE_URL,
    model: str = "openai/gpt-oss-20b",
    max_tokens: int = 256,
    stream: bool = False,
) -> str:
    """
    Send a chat message to a server running `transformers serve`.
    Returns the assistant reply text. No logits/logprobs (server does not provide them).

    Args:
        user_message: Text sent as the single user turn.
        base_url: Base URL of the OpenAI-compatible API (including ``/v1``).
        model: Model name placed in the request payload.
        max_tokens: Completion length limit placed in the request payload.
        stream: When True the server is asked to stream; the streamed
            ``data:`` chunks are read to the end and their text deltas are
            joined, so the return value is still the complete reply.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    url = f"{base_url.rstrip('/')}/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": user_message}],
        "max_tokens": max_tokens,
        "stream": stream,
    }
    # stream=stream keeps the connection open for incremental reading only
    # when a streamed body is expected; the context manager always releases it.
    with requests.post(url, json=payload, timeout=60, stream=stream) as resp:
        resp.raise_for_status()
        if stream:
            # A streamed reply is a sequence of SSE events, not one JSON
            # document, so resp.json() cannot be used here.
            return _join_stream_deltas(resp.iter_lines())
        data = resp.json()

    # Non-stream: choices[0].message.content
    choices = data.get("choices", [])
    if not choices:
        return ""
    msg = choices[0].get("message", {})
    return msg.get("content", "") or ""
|
||||
|
||||
|
||||
def get_logprobs_via_hf_api(
    messages: List[Dict[str, str]],
    *,
    base_url: str = HF_API_URL,
    max_new_tokens: int = 64,
    top_k: int = 10,
) -> Dict[str, Any]:
    """
    POST *messages* to the sc/hf_api ``/generate`` endpoint and return its JSON.

    The request always uses greedy settings (``temperature`` 0.0,
    ``do_sample`` False). The decoded response body is returned unchanged;
    an error status raises via ``raise_for_status()``.
    """
    endpoint = base_url.rstrip("/") + "/generate"
    request_body: Dict[str, Any] = {
        "messages": messages,
        "max_new_tokens": max_new_tokens,
        "top_k": top_k,
        "temperature": 0.0,
        "do_sample": False,
    }
    response = requests.post(endpoint, json=request_body, timeout=120)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Tiny CLI: `chat [message]` talks to transformers serve,
    # `logprobs [message]` talks to sc/hf_api.
    import sys

    argv = sys.argv
    if len(argv) < 2:
        for usage_line in (
            "Usage: python -m sc.transformers_serve_client_example chat|logprobs [message]",
            " chat -> call transformers serve /v1/chat/completions (no logprobs)",
            " logprobs -> call sc/hf_api /generate (returns top-k logprobs)",
        ):
            print(usage_line)
        sys.exit(0)

    cmd = argv[1].lower()
    # The message is optional; a canned greeting is used when it is absent.
    message = "Hello, how are you?"
    if len(argv) > 2:
        message = argv[2]
    message = message.strip()

    if cmd == "chat":
        print("Reply:", chat_with_transformers_serve(message))
    elif cmd == "logprobs":
        result = get_logprobs_via_hf_api([{"role": "user", "content": message}])
        print("Generated:", result.get("generated_text", ""))
        print("Steps (first 3):")
        for step in result.get("steps", [])[:3]:
            # Show only the five best alternatives to keep the line readable.
            top_tokens = [alt.get("token") for alt in step.get("topk", [])[:5]]
            print(
                " token:",
                repr(step.get("token")),
                "logprob:",
                step.get("logprob"),
                "topk:",
                top_tokens,
            )
    else:
        print("Unknown command. Use 'chat' or 'logprobs'.")
        sys.exit(1)
|
||||
Reference in New Issue
Block a user