# ArqonzChat / app.py
import os
from pathlib import Path

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Writable Hugging Face cache directory (TRANSFORMERS_CACHE overrides the default)
cache_dir = os.getenv("TRANSFORMERS_CACHE", "/cache")
os.makedirs(cache_dir, exist_ok=True)
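# Note: recent transformers releases prefer HF_HOME for cache configuration;
# TRANSFORMERS_CACHE is deprecated in v4.x but still honored.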

# Access token for gated or private models (e.g. set HF_TOKEN as a Space secret)
hf_token = os.getenv("HF_TOKEN")

# Load the model and tokenizer
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    cache_dir=cache_dir,
    device_map="auto",   # requires the accelerate package
    torch_dtype="auto",  # use the dtype stored in the checkpoint
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    return_full_text=False,  # return only the completion, not the echoed prompt
)
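
# A minimal alternative prompt builder (a sketch, not part of the original app):
# the tokenizer's built-in chat template produces the same [INST] ... [/INST]
# framing used below and stays correct if the model's template ever changes.
#
#   def build_prompt(question: str) -> str:
#       messages = [{"role": "user", "content": question}]
#       return tokenizer.apply_chat_template(
#           messages, tokenize=False, add_generation_prompt=True
#       )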

# Create the FastAPI app
app = FastAPI()

# Serve static files (JS, CSS, etc.). Caution: mounting "." exposes every file
# in the working directory, including this script; a dedicated static/ folder
# would be safer.
app.mount("/static", StaticFiles(directory="."), name="static")

# Route: serve index.html at the root
@app.get("/", response_class=HTMLResponse)
async def root():
    html_path = Path("index.html")
    return HTMLResponse(content=html_path.read_text(), status_code=200)

# Route: chat API
@app.post("/api")
async def ask_ai(request: Request):
    data = await request.json()
    question = data.get("question", "").strip()
    if not question:
        return JSONResponse(content={"answer": "❗ Please enter a valid question."})
    # Wrap the question in Mistral's instruction format
    prompt = f"[INST] {question} [/INST]"
    output = pipe(prompt)[0]["generated_text"]
    return JSONResponse(content={"answer": output.strip()})
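
# Example request (assuming the app is served with something like
# `uvicorn app:app --host 0.0.0.0 --port 7860`; the port is an assumption):
#
#   curl -X POST http://localhost:7860/api \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is FastAPI?"}'
#
# Expected response shape: {"answer": "..."}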