from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from llama_cpp import Llama
from fastapi.responses import PlainTextResponse, JSONResponse
import os
import time
import uuid

app = FastAPI()
llm = None

# Models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256


class ModelInfo(BaseModel):
    id: str
    name: str
    description: str


# Load your models info here or dynamically from disk/config
AVAILABLE_MODELS = [
    ModelInfo(id="llama2", name="Llama 2", description="Meta Llama 2 model"),
    # Add more models if you want
]


@app.on_event("startup")
def load_model():
    # Load the GGUF model referenced in /tmp/model_path.txt into llama.cpp.
    # Assumed wiring: registered as a startup hook so `llm` is ready before requests arrive.
    global llm
    model_path_file = "/tmp/model_path.txt"
    if not os.path.exists(model_path_file):
        raise RuntimeError(f"Model path file not found: {model_path_file}")
    with open(model_path_file, "r") as f:
        model_path = f.read().strip()
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at path: {model_path}")
    llm = Llama(model_path=model_path)


# NOTE: the route paths on the endpoints below are assumptions inferred from the
# handler names and the Ollama-style API this app mimics; adjust them as needed.
@app.get("/", response_class=PlainTextResponse)
async def root():
    return "Ollama is running"


@app.get("/health")
async def health_check():
    return {"status": "ok"}


@app.get("/api/tags")
async def api_tags():
    # Ollama-style tag listing with placeholder metadata
    return JSONResponse(content={
        "models": [
            {
                "name": "phi-2",
                "modified_at": "2025-06-01T00:00:00Z",
                "size": 2147483648,
                "digest": "sha256:placeholderdigest",
                "details": {
                    "format": "gguf",
                    "family": "phi",
                    "families": ["phi"]
                }
            }
        ]
    })


@app.get("/models")
async def list_models():
    # Return available models info
    return [model.dict() for model in AVAILABLE_MODELS]


@app.get("/models/{model_id}")
async def get_model(model_id: str):
    for model in AVAILABLE_MODELS:
        if model.id == model_id:
            return model.dict()
    raise HTTPException(status_code=404, detail="Model not found")


@app.post("/v1/chat/completions")
async def chat(req: ChatRequest):
    global llm
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not initialized.")
    # Validate model - simple check
    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")
    # Construct a plain role-prefixed prompt from the chat messages
    prompt = ""
    for m in req.messages:
        prompt += f"{m.role}: {m.content}\n"
    prompt += "assistant:"
    output = llm(
        prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["user:", "assistant:"]
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()
    response = {
        "id": str(uuid.uuid4()),
        "created": int(time.time()),
        "model": req.model,
        "choices": [
            {
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }
        ]
    }
    return response
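

# Usage sketch: run the API with uvicorn and exercise the chat endpoint with curl.
# Assumptions: uvicorn is installed (the usual ASGI server for FastAPI), the app
# listens on port 7860 (the default Hugging Face Spaces port), and the chat route
# is the /v1/chat/completions path declared above; adjust any of these to match
# your deployment.
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama2", "messages": [{"role": "user", "content": "Hello"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)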