# ArqonzChat / app.py

import os
from pathlib import Path
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Set up the Hugging Face cache directory
cache_dir = os.getenv("TRANSFORMERS_CACHE", "/cache")
os.makedirs(cache_dir, exist_ok=True)
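# Note: newer transformers releases prefer the HF_HOME environment variable for
# cache configuration; TRANSFORMERS_CACHE still works but may emit a deprecation warning.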

# Optional token (for private models)
hf_token = os.getenv("HF_TOKEN")

# Load model and tokenizer
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    cache_dir=cache_dir,
    device_map="auto",   # spread layers across available devices (requires accelerate)
    torch_dtype="auto"   # use the dtype stored in the checkpoint
)

# Build the generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,          # sampling must be enabled for temperature/top_p to take effect
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    return_full_text=False   # return only the completion, not the echoed prompt
)
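
# Quick local smoke test of the pipeline (hypothetical usage, not run by the app):
#   print(pipe("[INST] Say hello. [/INST]")[0]["generated_text"])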

# Initialize the FastAPI app
app = FastAPI()

# Serve static files such as script.js (note: mounting "." exposes every file
# in the repo, including app.py itself)
app.mount("/static", StaticFiles(directory="."), name="static")

# Route: Serve index.html at root
@app.get("/", response_class=HTMLResponse)
async def serve_home():
    html_path = Path("index.html")
    return HTMLResponse(content=html_path.read_text(), status_code=200)

# Route: Chat API
@app.post("/api")
async def ask_ai(request: Request):
    data = await request.json()
    question = data.get("question", "").strip()
    if not question:
        return JSONResponse(content={"answer": "❗ Please enter a valid question."})
    # Wrap the question in Mistral's [INST] instruction format
    prompt = f"[INST] {question} [/INST]"
    try:
        output = pipe(prompt)[0]["generated_text"]
        return JSONResponse(content={"answer": output.strip()})
    except Exception as e:
        return JSONResponse(content={"answer": f"⚠️ Error: {str(e)}"})
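
# Example request against this endpoint (assuming the app is served on
# localhost:8000; adjust host/port to your deployment):
#   curl -X POST http://localhost:8000/api \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is FastAPI?"}'
# The response has the shape {"answer": "..."}.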

# Optional: serve script.js directly if the HTML does not reference the /static path
@app.get("/script.js")
async def serve_script():
    return FileResponse("script.js", media_type="application/javascript")
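
# Optional local entry point: a minimal sketch for running this file directly.
# On Hugging Face Spaces the server is normally launched externally
# (e.g. `uvicorn app:app`); port 7860 is the Spaces convention.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)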