# ArqonzChat / app.py
import os

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Set up a safe cache directory for Hugging Face downloads
cache_dir = os.getenv("TRANSFORMERS_CACHE", "/cache")  # Use environment variable or default to /cache
os.makedirs(cache_dir, exist_ok=True)

# Optional: a token is only needed if you're accessing a private or gated model
hf_token = os.getenv("HF_TOKEN")

# Load tokenizer and model
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        cache_dir=cache_dir,
        device_map="auto",   # or "cpu" if no GPU
        torch_dtype="auto",  # will default to float32 on CPU
    )
except Exception as e:
    raise RuntimeError(f"Failed to load model or tokenizer: {str(e)}") from e

# Load pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    return_full_text=False,  # return only the generated answer, not the echoed prompt
)
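# A quick smoke test of the pipeline (a sketch; uncomment to verify generation
# at startup; with return_full_text=False the prompt is not echoed back):
# print(pipe("[INST] Say hello. [/INST]")[0]["generated_text"])
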
# Init FastAPI app
app = FastAPI()

@app.post("/api")
async def ask_ai(request: Request):
    try:
        data = await request.json()
        question = data.get("question", "").strip()
        if not question:
            return JSONResponse(content={"answer": "❗ Please enter a valid question."})

        # Wrap the question in Mistral's [INST] ... [/INST] instruction format
        prompt = f"[INST] {question} [/INST]"
        output = pipe(prompt)[0]["generated_text"]
        return JSONResponse(content={"answer": output.strip()})
    except Exception as e:
        return JSONResponse(content={"answer": f"⚠️ Error: {str(e)}"})