Sirawitch committed
Commit b8408d1 · verified · 1 Parent(s): 2ef8549

Update app.py

Files changed (1): app.py (+28 -4)
app.py CHANGED
@@ -9,7 +9,26 @@ app = FastAPI()
 # Load the model and tokenizer
 model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+
+# Check whether a GPU is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load the model with settings appropriate to the device
+if device == "cuda":
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        low_cpu_mem_usage=True
+    )
+else:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True
+    )
+
+model.to(device)
 
 class Query(BaseModel):
     queryResult: Optional[dict] = None
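
Note on this hunk: when device_map="auto" is used, Accelerate dispatches the weights itself, so the unconditional model.to(device) added at the end is redundant on the GPU path (and newer transformers releases can refuse .to() on a dispatched model). A minimal sketch of the safer shape, offered as a suggestion rather than as part of this commit:

# Sketch only: move the model manually only when device_map did not place it.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # device_map="auto" lets Accelerate put the weights on the GPU(s);
    # no manual model.to(device) is needed afterwards.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
else:
    # Loads on CPU by default; float32 avoids poorly supported fp16 CPU ops.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
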
@@ -25,9 +44,10 @@ async def webhook(query: Query):
 
         # Build the prompt and generate the response text
         prompt = f"Human: {user_query}\nAI:"
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
-        output = model.generate(input_ids, max_new_tokens=100, temperature=0.7)
+        with torch.no_grad():
+            output = model.generate(input_ids, max_new_tokens=100, temperature=0.7)
         response = tokenizer.decode(output[0], skip_special_tokens=True)
 
         # Extract the AI's part of the reply
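
Note on this hunk: generate() as written runs greedy decoding, so temperature=0.7 has no effect unless sampling is enabled (recent transformers versions warn about exactly this). If sampled output is the intent, a sketch under that assumption:

with torch.no_grad():
    output = model.generate(
        input_ids,
        max_new_tokens=100,
        do_sample=True,   # required for temperature to take effect
        temperature=0.7,
    )
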
@@ -35,4 +55,8 @@ async def webhook(query: Query):
 
         return {"fulfillmentText": ai_response}
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
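
Once the server is running on port 7860, a request like the one below can serve as a smoke test. The /webhook path and the queryResult["queryText"] field are assumptions inferred from the handler name and the Dialogflow-style queryResult/fulfillmentText shape; neither appears verbatim in this diff:

import requests

# Hypothetical payload; adjust the route and fields to the actual app.
payload = {"queryResult": {"queryText": "Hello"}}
resp = requests.post("http://localhost:7860/webhook", json=payload)
print(resp.status_code, resp.json())  # expect {"fulfillmentText": "..."}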