| """ |
| ╔══════════════════════════════════════════════════════════════╗ |
| ║ Granite 4.0 ONNX Inference Server ║ |
| ║ Model: onnx-community/granite-4.0-h-350m-ONNX ║ |
| ╚══════════════════════════════════════════════════════════════╝ |
| """ |
|
|
import asyncio
import os
import threading
import time
import uuid
from collections import deque
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Generator, List, Optional, Tuple

import numpy as np
import onnxruntime
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import AutoConfig, AutoTokenizer
|
|
| |
# Hugging Face repo id and the specific quantized weight file we load.
MODEL_ID = "onnx-community/granite-4.0-h-350m-ONNX"
MODEL_FILENAME = "model_q4"


# Populated once by load_model() at startup; read-only afterwards.
decoder_session = None  # onnxruntime.InferenceSession
tokenizer = None        # transformers tokenizer for MODEL_ID
config = None           # transformers AutoConfig for MODEL_ID
|
|
| |
# Rolling server metrics shared across request handlers. Every read and
# write is guarded by ``metrics_lock`` because handlers run on worker
# threads as well as the event loop.
metrics = {
    "total_requests": 0,          # lifetime request count (/chat + /chat/stream)
    "active_requests": 0,         # requests currently in flight
    "total_tokens_generated": 0,  # completion tokens across all requests
    "total_prompt_tokens": 0,     # prompt tokens across all requests
    "request_latencies": deque(maxlen=100),         # last 100 /chat latencies (ms)
    "tokens_per_second_history": deque(maxlen=50),  # last 50 per-request TPS values
    "errors": 0,                  # model-load failures + request failures
    "start_time": time.time(),    # process start, for uptime reporting
    "last_tps": 0.0,              # TPS of the most recent /chat request
    "model_loaded": False,        # flipped to True once load_model() succeeds
    "model_loading": True,        # False once loading finishes (success or failure)
}
metrics_lock = threading.Lock()
|
|
|
|
| |
class Message(BaseModel):
    """One turn of a chat conversation (passed to the chat template)."""
    role: str     # e.g. "user" / "assistant" / "system" — forwarded as-is
    content: str  # the turn's text
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for /chat and /chat/stream."""
    messages: List[Message]    # full conversation history, oldest first
    max_new_tokens: int = 512  # generation budget per request
    # NOTE(review): accepted but never used — generate_tokens() decodes with
    # pure argmax, so temperature has no effect on output.
    temperature: float = 1.0
    stream: bool = False       # also unused; streaming is a separate endpoint
|
|
|
|
class ChatResponse(BaseModel):
    """Response body for the non-streaming /chat endpoint."""
    id: str                    # short request id (first 8 chars of a uuid4)
    content: str               # decoded completion text
    prompt_tokens: int         # token count of the templated prompt
    completion_tokens: int     # tokens actually generated
    total_tokens: int          # prompt_tokens + completion_tokens
    latency_ms: float          # wall-clock request latency
    tokens_per_second: float   # decode throughput for this request
|
|
|
|
| |
def load_model():
    """Download the model snapshot, build the ONNX session, and load the
    tokenizer/config.

    Runs once at startup (called from ``lifespan`` on a thread executor).
    Populates the module globals ``decoder_session``, ``tokenizer`` and
    ``config``, then flips the ``model_loaded``/``model_loading`` metric
    flags. On failure the error is counted and re-raised so startup fails
    loudly instead of serving a half-initialized model.
    """
    global decoder_session, tokenizer, config
    print(f"[INFO] Downloading model {MODEL_ID}...")

    try:
        # Skip weight formats we never load; only onnx/model_q4.onnx is used.
        model_dir = snapshot_download(
            MODEL_ID,
            ignore_patterns=["*.msgpack", "*.h5", "flax_model*",
                             "model.onnx", "model_fp16.onnx", "model_q4f16.onnx"],
        )
        # ``os`` is imported at module top level (was a function-local import).
        model_path = os.path.join(model_dir, "onnx", f"{MODEL_FILENAME}.onnx")

        print(f"[INFO] Loading ONNX session from {model_path}...")
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4  # fixed CPU-thread budget

        decoder_session = onnxruntime.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )

        print("[INFO] Loading tokenizer and config...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        config = AutoConfig.from_pretrained(MODEL_ID)

        with metrics_lock:
            metrics["model_loaded"] = True
            metrics["model_loading"] = False

        print("[INFO] ✅ Model loaded successfully!")

    except Exception as e:
        with metrics_lock:
            metrics["model_loading"] = False
            metrics["errors"] += 1
        print(f"[ERROR] Failed to load model: {e}")
        raise
|
|
|
|
| |
def init_cache(batch_size: int, dtype=np.float32):
    """Build the zero-filled past-state feed for every decoder layer.

    Attention layers get empty (seq-len 0) key/value tensors; mamba layers
    get zeroed conv and SSM state tensors. Keys follow the ONNX input
    naming: ``past_key_values.{i}.key/value``, ``past_conv.{i}``,
    ``past_ssm.{i}``. Shapes come from the loaded model ``config``.
    """
    head_dim = config.hidden_size // config.num_attention_heads
    conv_width = config.mamba_d_conv
    d_state = config.mamba_d_state
    n_groups = config.mamba_n_groups
    # Conv channels: expanded hidden size plus the grouped B/C state inputs.
    conv_channels = (config.mamba_expand * config.hidden_size) + (2 * n_groups * d_state)

    cache = {}
    for idx, kind in enumerate(config.layer_types):
        if kind == "attention":
            kv_shape = [batch_size, config.num_key_value_heads, 0, head_dim]
            cache[f"past_key_values.{idx}.key"] = np.zeros(kv_shape, dtype=dtype)
            cache[f"past_key_values.{idx}.value"] = np.zeros(kv_shape, dtype=dtype)
        elif kind == "mamba":
            cache[f"past_conv.{idx}"] = np.zeros(
                [batch_size, conv_channels, conv_width], dtype=dtype
            )
            cache[f"past_ssm.{idx}"] = np.zeros(
                [batch_size, config.mamba_n_heads, config.mamba_d_head, d_state], dtype=dtype
            )
    return cache
|
|
|
|
| |
def generate_tokens(input_ids: np.ndarray, attention_mask: np.ndarray,
                    max_new_tokens: int = 512) -> Generator[Tuple[str, bool, float], None, List[int]]:
    """Greedily decode up to ``max_new_tokens`` tokens, one per step.

    This is a *synchronous* generator (run it on a worker thread); the
    previous ``AsyncGenerator`` annotation was incorrect. Each step yields
    ``(token_str, is_done, tokens_per_second)`` — the old docstring's
    2-tuple claim was also wrong. The list of generated token ids is
    returned as the generator's StopIteration value.

    Decoding is pure argmax; no temperature/sampling is applied.
    """
    dtype = np.float32
    cache = init_cache(batch_size=1, dtype=dtype)
    output_names = [o.name for o in decoder_session.get_outputs()]
    # config.eos_token_id may be a list; only its first id stops decoding.
    # NOTE(review): any additional EOS ids are ignored — confirm acceptable.
    eos_token_id = config.eos_token_id if not isinstance(
        config.eos_token_id, list) else config.eos_token_id[0]

    generated = []
    t_start = time.perf_counter()

    for step in range(max_new_tokens):
        feed_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        outputs = decoder_session.run(None, feed_dict | cache)
        named_outputs = dict(zip(output_names, outputs))

        # Greedy pick from the final position's logits (outputs[0] is logits).
        next_token = outputs[0][:, -1].argmax(-1, keepdims=True)
        attention_mask = np.concatenate(
            [attention_mask, np.ones_like(next_token, dtype=np.int64)], axis=-1
        )
        # Feed only the new token next step; history lives in the cache.
        input_ids = next_token

        # Carry layer state forward: "past_key_values.*" <- "present.*",
        # "past_conv.*"/"past_ssm.*" <- "present_conv.*"/"present_ssm.*".
        for name in cache:
            new_name = name.replace("past_key_values", "present").replace("past_", "present_")
            cache[name] = named_outputs[new_name]

        token_id = int(next_token[0, 0])
        generated.append(token_id)

        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
        elapsed = time.perf_counter() - t_start
        tps = (step + 1) / elapsed if elapsed > 0 else 0

        is_done = token_id == eos_token_id
        yield token_str, is_done, tps

        if is_done:
            break

    return generated
|
|
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Block server startup until the model is loaded; no shutdown work.

    ``load_model`` is synchronous and slow (download + session build), so
    it runs on the default thread executor to keep the loop responsive.
    """
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, load_model)
    yield
|
|
|
|
| |
# FastAPI application; ``lifespan`` blocks startup until load_model() finishes.
app = FastAPI(
    title="Granite 4.0 ONNX Server",
    description="High-performance inference server for granite-4.0-h-350m-ONNX",
    version="1.0.0",
    lifespan=lifespan,
)

# NOTE(review): wildcard CORS is wide open — fine for a local demo, but
# tighten allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
| |
@app.get("/health")
def health():
    """Readiness probe: report model load state, model id and uptime."""
    with metrics_lock:
        loaded = metrics["model_loaded"]
        started = metrics["start_time"]
    return {
        "status": "ready" if loaded else "loading",
        "model": MODEL_ID,
        "uptime_seconds": round(time.time() - started, 1),
    }
|
|
|
|
@app.get("/metrics")
def get_metrics():
    """Return a consistent snapshot of the server's rolling counters."""
    with metrics_lock:
        uptime = time.time() - metrics["start_time"]
        latencies = metrics["request_latencies"]
        avg_latency = sum(latencies) / len(latencies) if latencies else 0
        snapshot = {
            "uptime_seconds": round(uptime, 1),
            "total_requests": metrics["total_requests"],
            "active_requests": metrics["active_requests"],
            "total_tokens_generated": metrics["total_tokens_generated"],
            "total_prompt_tokens": metrics["total_prompt_tokens"],
            "average_latency_ms": round(avg_latency, 2),
            "last_tokens_per_second": round(metrics["last_tps"], 2),
            "tps_history": list(metrics["tokens_per_second_history"]),
            "errors": metrics["errors"],
            "model_loaded": metrics["model_loaded"],
            "model_loading": metrics["model_loading"],
            # Divisor clamped to >= 1 minute so early samples aren't inflated.
            "requests_per_minute": round(metrics["total_requests"] / max(uptime / 60, 1), 2),
        }
    return snapshot
|
|
|
|
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Run a full (non-streaming) chat completion.

    Applies the model's chat template to the conversation, decodes up to
    ``req.max_new_tokens`` tokens greedily on a worker thread, updates the
    shared metrics, and returns the text plus timing/token accounting.
    Raises 503 while the model is loading and 500 on generation failure.
    """
    # Read the load flag under the lock, consistent with every other
    # metrics access in this module (the old code read it unlocked).
    with metrics_lock:
        loaded = metrics["model_loaded"]
    if not loaded:
        raise HTTPException(status_code=503, detail="Model still loading, please wait...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    t0 = time.perf_counter()
    request_id = str(uuid.uuid4())[:8]

    try:
        messages = [{"role": m.role, "content": m.content} for m in req.messages]
        # get_running_loop() replaces the deprecated get_event_loop().
        loop = asyncio.get_running_loop()

        # Tokenization can be slow for long histories; keep it off the loop.
        inputs = await loop.run_in_executor(
            None,
            lambda: tokenizer.apply_chat_template(
                messages, add_generation_prompt=True,
                tokenize=True, return_dict=True, return_tensors="np"
            )
        )

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        prompt_tokens = int(input_ids.shape[1])

        full_text = ""
        final_tps = 0.0
        completion_tokens = 0

        def run_generation():
            # Drain the synchronous generator on a worker thread so the
            # event loop stays responsive during ONNX inference.
            # Note: req.temperature is not applied — decoding is argmax.
            nonlocal full_text, final_tps, completion_tokens
            for token_str, is_done, tps in generate_tokens(
                input_ids, attention_mask, req.max_new_tokens
            ):
                full_text += token_str
                completion_tokens += 1
                final_tps = tps
                if is_done:
                    break

        await loop.run_in_executor(None, run_generation)

        latency_ms = (time.perf_counter() - t0) * 1000

        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["total_tokens_generated"] += completion_tokens
            metrics["total_prompt_tokens"] += prompt_tokens
            metrics["request_latencies"].append(latency_ms)
            metrics["tokens_per_second_history"].append(round(final_tps, 2))
            metrics["last_tps"] = final_tps

        return ChatResponse(
            id=request_id,
            content=full_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            latency_ms=round(latency_ms, 2),
            tokens_per_second=round(final_tps, 2),
        )

    except Exception as e:
        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["errors"] += 1
        # Chain the cause so the original traceback isn't lost in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.post("/chat/stream")
async def chat_stream(req: ChatRequest):
    """Stream a chat completion as server-sent events.

    Each generated token is emitted as a ``data:`` SSE frame, followed by
    a terminating ``data: [DONE]`` frame. Raises 503 while loading.
    """
    # Locked read, consistent with the rest of the module.
    with metrics_lock:
        loaded = metrics["model_loaded"]
    if not loaded:
        raise HTTPException(status_code=503, detail="Model still loading...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    messages = [{"role": m.role, "content": m.content} for m in req.messages]
    # get_running_loop() replaces the deprecated get_event_loop().
    loop = asyncio.get_running_loop()
    try:
        # Tokenize on a worker thread like the non-streaming endpoint does;
        # previously this ran synchronously and blocked the event loop.
        inputs = await loop.run_in_executor(
            None,
            lambda: tokenizer.apply_chat_template(
                messages, add_generation_prompt=True,
                tokenize=True, return_dict=True, return_tensors="np"
            )
        )
    except Exception:
        # Without this, a tokenization failure leaked an active_requests
        # slot: the stream's ``finally`` never runs if we raise before the
        # StreamingResponse is returned.
        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["errors"] += 1
        raise

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    async def event_stream():
        completion_tokens = 0
        try:
            gen = generate_tokens(input_ids, attention_mask, req.max_new_tokens)

            def next_token():
                # Each next() runs one ONNX decode step on the worker thread.
                return next(gen, None)

            while True:
                result = await loop.run_in_executor(None, next_token)
                if result is None:
                    break
                token_str, is_done, tps = result
                completion_tokens += 1
                # NOTE(review): a token containing "\n" would break SSE
                # framing; kept as-is to preserve the wire format — clients
                # must tolerate it (or this should switch to JSON payloads).
                yield f"data: {token_str}\n\n"
                if is_done:
                    break

            yield "data: [DONE]\n\n"
        finally:
            with metrics_lock:
                metrics["active_requests"] -= 1
                metrics["total_tokens_generated"] += completion_tokens

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def ui():
    """Serve the bundled single-page UI from the container's static dir."""
    try:
        # Explicit encoding — the page may contain non-ASCII glyphs, and
        # the platform default encoding is not guaranteed to be UTF-8.
        with open("/app/static/index.html", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        # A missing asset is a deployment problem; report 404 instead of
        # letting the handler crash with an unhandled 500.
        raise HTTPException(status_code=404, detail="UI assets not found")
|
|