from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

try:
    model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Check whether a GPU is available
    if torch.cuda.is_available():
        logger.info("GPU is available. Using CUDA.")
        device = "cuda"
    else:
        logger.info("No GPU found. Using CPU.")
        device = "cpu"

    # Settings for loading the model
    model_kwargs = {
        "torch_dtype": torch.float32 if device == "cpu" else torch.float16,
        "low_cpu_mem_usage": True,
    }

    if device == "cuda":
        # Quantize to 8-bit via bitsandbytes to cut VRAM use
        # (roughly one byte per weight, ~8 GB for this 8B model)
        from transformers import BitsAndBytesConfig
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    # Load the model; device_map="auto" already places weights on the GPU,
    # and 8-bit bitsandbytes models cannot be moved with .to(), so only
    # move the model explicitly when running on CPU
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if device == "cuda" else None,
        **model_kwargs
    )

    if device == "cpu":
        model.to(device)
    logger.info(f"Model loaded successfully on {device}")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise
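
# Request body schema: the user utterance may arrive nested under
# queryResult.queryText (Dialogflow-style) or as a top-level queryText field.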

class Query(BaseModel):
    queryResult: Optional[dict] = None
    queryText: Optional[str] = None

@app.post("/webhook")
async def webhook(query: Query):
    try:
        # Accept the utterance from either the nested field or the top-level one
        user_query = (query.queryResult or {}).get('queryText') or query.queryText
        
        if not user_query:
            raise HTTPException(status_code=400, detail="No query text provided")
        
        prompt = f"Human: {user_query}\nAI:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            # do_sample=True lets temperature take effect; Llama tokenizers
            # define no pad token, so reuse EOS for padding
            output = model.generate(**inputs, max_new_tokens=100, temperature=0.7,
                                    do_sample=True, pad_token_id=tokenizer.eos_token_id)
        response = tokenizer.decode(output[0], skip_special_tokens=True)

        # Keep only the text generated after the "AI:" marker
        ai_response = response.split("AI:")[-1].strip()
        
        return {"fulfillmentText": ai_response}
    except HTTPException:
        # Re-raise client errors (e.g. the 400 above) instead of masking them as 500
        raise
    except Exception as e:
        logger.error(f"Error in webhook: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
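
# A hypothetical smoke test once the server is running (payload shape matches
# the fields the webhook reads):
#   curl -X POST http://localhost:7860/webhook \
#     -H "Content-Type: application/json" \
#     -d '{"queryResult": {"queryText": "Hello"}}'
# Expected response: {"fulfillmentText": "..."}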

if __name__ == "__main__":
    import uvicorn
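    # 7860 is the default port Hugging Face Spaces expects an app to listen on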
    uvicorn.run(app, host="0.0.0.0", port=7860)