# kkulchatbot / app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

try:
    model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # 4-bit quantization configuration: keeps the 8B model within a single
    # GPU's memory budget while running compute in float16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",  # let accelerate decide weight placement
        low_cpu_mem_usage=True,
    )
    logger.info(f"Model loaded successfully on {device}")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise
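
# A quick interactive smoke test of the loaded model (hypothetical snippet,
# not part of the serving path) could look like:
#
#   ids = tokenizer.encode("Human: hello\nAI:", return_tensors="pt").to(model.device)
#   out = model.generate(ids, max_new_tokens=20)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))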

class Query(BaseModel):
    # Accepts either a Dialogflow-style body (text nested under
    # queryResult.queryText) or a plain top-level queryText.
    queryResult: Optional[dict] = None
    queryText: Optional[str] = None
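
# Example bodies the Query model accepts (illustrative values):
#   {"queryResult": {"queryText": "Hello"}}   # Dialogflow-style webhook call
#   {"queryText": "Hello"}                    # direct call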
@app.post("/webhook")
async def webhook(query: Query):
try:
user_query = query.queryResult.get('queryText') if query.queryResult else query.queryText
if not user_query:
raise HTTPException(status_code=400, detail="No query text provided")
prompt = f"Human: {user_query}\nAI:"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
with torch.no_grad():
output = model.generate(input_ids, max_new_tokens=100, temperature=0.7)
response = tokenizer.decode(output[0], skip_special_tokens=True)
ai_response = response.split("AI:")[-1].strip()
return {"fulfillmentText": ai_response}
except Exception as e:
logger.error(f"Error in webhook: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
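
# Example request for local testing (assuming the uvicorn defaults below):
#   curl -X POST http://localhost:7860/webhook \
#        -H "Content-Type: application/json" \
#        -d '{"queryResult": {"queryText": "Hello"}}'
# Expected response shape: {"fulfillmentText": "..."}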

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)