kkulchatbot / app.py
Sirawitch's picture
Update app.py
bff2feb verified
raw
history blame
2.55 kB
import asyncio
import logging
import os
import threading
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

try:
    model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Check whether a GPU is available.
    if torch.cuda.is_available():
        logger.info("GPU is available. Using CUDA.")
        device = "cuda"
    else:
        logger.info("No GPU found. Using CPU.")
        device = "cpu"

    # Model-loading options: float32 on CPU; float16 plus 8-bit bitsandbytes
    # quantization on GPU.
    model_kwargs = {
        "torch_dtype": torch.float32 if device == "cpu" else torch.float16,
        "low_cpu_mem_usage": True,
    }
    if device == "cuda":
        from transformers import BitsAndBytesConfig
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    # Load the model.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if device == "cuda" else None,
        **model_kwargs
    )
    if device == "cpu":
        # Only move the model explicitly on the CPU path. On the CUDA path,
        # device_map="auto" has already placed the weights. transformers also
        # raises ValueError when `.to()` is called on an 8-bit bitsandbytes
        # model (at least in 4.x releases), which previously crashed startup
        # on GPU.
        model.to(device)
    logger.info(f"Model loaded successfully on {device}")
except Exception as e:
    # logger.exception keeps the traceback; re-raise so the app fails to start.
    logger.exception("Error loading model: %s", e)
    raise
class Query(BaseModel):
    """Request body accepted by the /webhook endpoint.

    Both fields are optional. The webhook handler reads the user's text from
    ``queryResult["queryText"]`` when ``queryResult`` is non-empty, and
    otherwise from ``queryText``.
    """

    # Optional nested payload; the handler reads its "queryText" key.
    # NOTE(review): looks like a Dialogflow-style queryResult object — confirm
    # the expected schema against the calling service.
    queryResult: Optional[dict] = None
    # Optional plain-text query, used when queryResult is missing or empty.
    queryText: Optional[str] = None
# Serializes model.generate calls. Generation used to run directly on the
# event loop and was therefore one-at-a-time; the lock keeps that behaviour
# (and bounds peak memory) now that it runs in worker threads.
_generation_lock = threading.Lock()


def _generate_reply(user_query) -> str:
    """Generate the model's reply to *user_query*.

    This is a blocking call, so run it off the event loop. It builds a
    "Human: ... AI:" prompt, samples up to 100 new tokens at temperature 0.7,
    and returns only the newly generated text, stripped of surrounding
    whitespace.
    """
    prompt = f"Human: {user_query}\nAI:"
    # tokenizer(...) also yields the attention mask that generate() expects.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # The tokenizer may define no pad token; fall back to EOS so generate()
    # does not have to guess (and warn).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    with _generation_lock, torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,  # without sampling, temperature is ignored
            temperature=0.7,
            pad_token_id=pad_token_id,
        )
    # Decode only the tokens after the prompt. Splitting the full decoded text
    # on "AI:" would truncate any reply that itself contains "AI:".
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()


@app.post("/webhook")
async def webhook(query: Query):
    """Answer a webhook request with a model-generated reply.

    The user text is taken from ``queryResult["queryText"]`` when
    ``queryResult`` is non-empty, otherwise from ``queryText``.

    Returns:
        ``{"fulfillmentText": <reply text>}``.

    Raises:
        HTTPException: 400 if no query text was provided; 500 if generation
            fails (the detail is the underlying error message).
    """
    user_query = query.queryResult.get('queryText') if query.queryResult else query.queryText
    if not user_query:
        # Raised outside the try below so the client receives a 400, rather
        # than having it caught and re-wrapped as a 500.
        raise HTTPException(status_code=400, detail="No query text provided")
    try:
        # model.generate is blocking; run it in the loop's default thread pool
        # so the event loop keeps serving other requests in the meantime.
        loop = asyncio.get_running_loop()
        ai_response = await loop.run_in_executor(None, _generate_reply, user_query)
    except Exception as e:
        logger.exception("Error in webhook: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"fulfillmentText": ai_response}
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces. The port defaults to the previously hard-coded
    # 7860 and can be overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))