# FastAPI app: serves a static chat frontend and a Mistral-7B-Instruct Q&A API.
# (The original paste carried Hugging Face Spaces build-log lines here: "Spaces:" / "Runtime error".)
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pathlib import Path
import os

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Hugging Face cache dir (Spaces convention: a writable /cache volume).
cache_dir = os.getenv("TRANSFORMERS_CACHE", "/cache")
os.makedirs(cache_dir, exist_ok=True)

# Access token for gated/private models; None is fine for public checkpoints.
hf_token = os.getenv("HF_TOKEN")

# Load model once at startup. NOTE(review): Mistral-7B-Instruct is a gated
# repo, so HF_TOKEN normally must be set or from_pretrained will raise.
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    cache_dir=cache_dir,
    device_map="auto",   # place weights on available GPU(s)/CPU automatically
    torch_dtype="auto",  # use the checkpoint's native dtype
)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)

# Create FastAPI app
app = FastAPI()

# Serve static files (JS, CSS, etc.) from the app directory.
app.mount("/static", StaticFiles(directory="."), name="static")
# Route: serve index.html at the site root.
# BUG FIX: the route decorator was missing, so this handler was never
# registered with the app.
@app.get("/")
async def root():
    """Return the chat UI page (index.html) as an HTML response."""
    html_path = Path("index.html")
    return HTMLResponse(content=html_path.read_text(), status_code=200)
# Route: chat API consumed by the frontend JS.
# BUG FIX: the route decorator was missing, so this handler was never
# registered. NOTE(review): the endpoint path is not visible in SOURCE —
# "/ask" assumed; confirm against the fetch() call in the frontend JS.
@app.post("/ask")
async def ask_ai(request: Request):
    """Answer a user question with the Mistral instruct model.

    Expects a JSON body like {"question": "..."} and returns
    {"answer": "..."}. An empty/missing question yields a polite error
    message (HTTP 200) that the frontend can display directly.
    """
    data = await request.json()
    question = data.get("question", "").strip()
    if not question:
        # NOTE(review): "β" looks like a mojibake'd emoji (e.g. ❌) — kept
        # byte-identical; verify the intended glyph.
        return JSONResponse(content={"answer": "β Please enter a valid question."})
    # Mistral instruct chat format: wrap the user turn in [INST] ... [/INST].
    prompt = f"[INST] {question} [/INST]"
    # BUG FIX: by default the pipeline's "generated_text" echoes the prompt
    # (including the [INST] tags) before the completion; return only the
    # newly generated answer text.
    output = pipe(prompt, return_full_text=False)[0]["generated_text"]
    return JSONResponse(content={"answer": output.strip()})