from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from llama_cpp import Llama
from fastapi.responses import PlainTextResponse, JSONResponse
import os
import time
import uuid

app = FastAPI()
llm = None


# Models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256


class ModelInfo(BaseModel):
    id: str
    name: str
    description: str


# Load your models info here or dynamically from disk/config
AVAILABLE_MODELS = [
    ModelInfo(id="llama2", name="Llama 2", description="Meta Llama 2 model"),
    # Add more models if you want
]


@app.on_event("startup")
def load_model():
    global llm
    model_path_file = "/tmp/model_path.txt"
    if not os.path.exists(model_path_file):
        raise RuntimeError(f"Model path file not found: {model_path_file}")
    with open(model_path_file, "r") as f:
        model_path = f.read().strip()
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at path: {model_path}")
    llm = Llama(model_path=model_path)
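

# Note: the startup hook above assumes an external step (for example the Space's
# entrypoint script, which is an assumption here, not shown in this file) has
# already written the absolute path of a GGUF model file into /tmp/model_path.txt.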
@app.get("/", response_class=PlainTextResponse)
async def root():
return "Ollama is running"
@app.get("/health")
async def health_check():
return {"status": "ok"}
@app.get("/api/tags")
async def api_tags():
return JSONResponse(content={
"models": [
{
"name": "phi-2",
"modified_at": "2025-06-01T00:00:00Z",
"size": 2147483648,
"digest": "sha256:placeholderdigest",
"details": {
"format": "gguf",
"family": "phi",
"families": ["phi"]
}
}
]
})
@app.get("/models")
async def list_models():
# Return available models info
return [model.dict() for model in AVAILABLE_MODELS]
@app.get("/models/{model_id}")
async def get_model(model_id: str):
for model in AVAILABLE_MODELS:
if model.id == model_id:
return model.dict()
raise HTTPException(status_code=404, detail="Model not found")
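

# Chat endpoint: flattens the message history into a plain-text prompt and runs
# it through the GGUF model loaded at startup.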
@app.post("/chat")
async def chat(req: ChatRequest):
global llm
if llm is None:
return {"error": "Model not initialized."}
# Validate model - simple check
if req.model not in [m.id for m in AVAILABLE_MODELS]:
raise HTTPException(status_code=400, detail="Unsupported model")
# Construct prompt from messages
prompt = ""
for m in req.messages:
prompt += f"{m.role}: {m.content}\n"
prompt += "assistant:"
    output = llm(
        prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["user:", "assistant:"]
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()
    response = {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }
        ]
    }
    return response
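

# Example usage (a sketch; assumes this file is named app.py and the server is
# started with `uvicorn app:app --host 0.0.0.0 --port 8000`; the module name and
# port are assumptions, not defined in this file):
#
#   curl http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama2", "messages": [{"role": "user", "content": "Hello"}]}'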