import os
import time
import uuid
from typing import List, Optional, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
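
# Configuration: model id and default generation budget come from environment variables.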
MODEL_ID = os.getenv("MODEL_ID", "LiquidAI/LFM2-1.2B")
DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))

app = FastAPI(title="OpenAI-compatible API for LiquidAI/LFM2-1.2B")

tokenizer = None
model = None
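
# Pick the computation dtype based on the available hardware.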
def get_dtype() -> torch.dtype:
    if torch.cuda.is_available():
        # Prefer bfloat16 if supported; else float16
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    # CPU
    return torch.float32
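
# Load the tokenizer and model once at startup; device_map="auto" places the weights.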
@app.on_event("startup")
def load_model():
    global tokenizer, model
    dtype = get_dtype()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    # Ensure eos/pad tokens exist
    if tokenizer.eos_token is None:
        tokenizer.eos_token = tokenizer.sep_token or tokenizer.pad_token or "</s>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
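
# Request/response schemas mirroring the OpenAI Chat Completions and Completions APIs.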
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: Optional[str] = Field(default=MODEL_ID)
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = None
    stop: Optional[List[str] | str] = None
    n: Optional[int] = 1


class CompletionRequest(BaseModel):
    model: Optional[str] = Field(default=MODEL_ID)
    prompt: str | List[str]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = None
    stop: Optional[List[str] | str] = None
    n: Optional[int] = 1


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

# Simple chat prompt formatter
def build_chat_prompt(messages: List[ChatMessage]) -> str:
    system_prefix = "You are a helpful assistant."
    system_msgs = [m.content for m in messages if m.role == "system"]
    if system_msgs:
        system_prefix = system_msgs[-1]
    conv: List[str] = [f"System: {system_prefix}"]
    for m in messages:
        if m.role == "system":
            continue
        role = "User" if m.role == "user" else ("Assistant" if m.role == "assistant" else m.role.capitalize())
        conv.append(f"{role}: {m.content}")
    conv.append("Assistant:")
    return "\n".join(conv)
def apply_stop_sequences(text: str, stop: Optional[List[str] | str]) -> str:
    if stop is None:
        return text
    stops = stop if isinstance(stop, list) else [stop]
    cut = len(text)
    for s in stops:
        if not s:
            continue
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]
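
# Run a single generation pass and return the new text plus token counts.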
def generate_once(prompt: str, temperature: float, top_p: float, max_new_tokens: int) -> Dict[str, Any]:
    assert tokenizer is not None and model is not None, "Model not loaded"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True if temperature and temperature > 0 else False,
        temperature=max(0.0, float(temperature or 0.0)),
        top_p=max(0.0, float(top_p or 1.0)),
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens (everything after the prompt).
    out = tokenizer.decode(gen_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return {
        "text": out,
        "prompt_tokens": inputs["input_ids"].numel(),
        "completion_tokens": gen_ids[0].shape[0] - inputs["input_ids"].shape[-1],
    }
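
# HTTP endpoints: docs redirect, health check, and the OpenAI-compatible routes.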
@app.get("/")
def root():
    return RedirectResponse(url="/docs")


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_ID}
@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionRequest):
    if req.n and req.n > 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
    max_new = req.max_tokens or DEFAULT_MAX_TOKENS
    prompt = build_chat_prompt(req.messages)
    g = generate_once(prompt, req.temperature or 0.7, req.top_p or 0.95, max_new)
    text = apply_stop_sequences(g["text"], req.stop)
    created = int(time.time())
    comp_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    usage = Usage(
        prompt_tokens=g["prompt_tokens"],
        completion_tokens=g["completion_tokens"],
        total_tokens=g["prompt_tokens"] + g["completion_tokens"],
    )
    return {
        "id": comp_id,
        "object": "chat.completion",
        "created": created,
        "model": req.model or MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": usage.dict(),
    }
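
# OpenAI-compatible legacy text completions endpoint (single prompt, n=1 only).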
@app.post("/v1/completions")
def completions(req: CompletionRequest):
    if req.n and req.n > 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
    prompts = req.prompt if isinstance(req.prompt, list) else [req.prompt]
    if len(prompts) != 1:
        raise HTTPException(status_code=400, detail="Only a single prompt is supported in this simple server.")
    max_new = req.max_tokens or DEFAULT_MAX_TOKENS
    g = generate_once(prompts[0], req.temperature or 0.7, req.top_p or 0.95, max_new)
    text = apply_stop_sequences(g["text"], req.stop)
    created = int(time.time())
    comp_id = f"cmpl-{uuid.uuid4().hex[:24]}"
    usage = Usage(
        prompt_tokens=g["prompt_tokens"],
        completion_tokens=g["completion_tokens"],
        total_tokens=g["prompt_tokens"] + g["completion_tokens"],
    )
    return {
        "id": comp_id,
        "object": "text_completion",
        "created": created,
        "model": req.model or MODEL_ID,
        "choices": [
            {
                "index": 0,
                "text": text,
                "finish_reason": "stop",
                "logprobs": None,
            }
        ],
        "usage": usage.dict(),
    }

if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
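
# Example request once the server is running (assumes the default PORT of 7860 and
# the /v1/chat/completions route registered above):
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}]}'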