Sirawitch committed
Commit 33ee4b1 · verified · Parent: b8408d1

Update app.py

Files changed (1): app.py (+12 -23)
app.py CHANGED
@@ -1,34 +1,25 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import Optional
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
 
 app = FastAPI()
 
-# Load the model and tokenizer
 model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# Check whether a GPU is available
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# Use BitsAndBytes for quantization
+config = AutoConfig.from_pretrained(model_name)
+config.quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
-# Load the model with appropriate settings
-if device == "cuda":
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        low_cpu_mem_usage=True
-    )
-else:
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float32,
-        low_cpu_mem_usage=True
-    )
-
-model.to(device)
+# Load the model with 8-bit quantization
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    config=config,
+    device_map="auto",
+    torch_dtype=torch.float16,
+)
 
 class Query(BaseModel):
     queryResult: Optional[dict] = None
@@ -42,15 +33,13 @@ async def webhook(query: Query):
     if not user_query:
         raise HTTPException(status_code=400, detail="No query text provided")
 
-    # Build the prompt and generate text
     prompt = f"Human: {user_query}\nAI:"
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
     with torch.no_grad():
         output = model.generate(input_ids, max_new_tokens=100, temperature=0.7)
     response = tokenizer.decode(output[0], skip_special_tokens=True)
 
-    # Extract the AI's portion of the response
    ai_response = response.split("AI:")[-1].strip()
 
     return {"fulfillmentText": ai_response}
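For comparison, current transformers releases document passing the quantization settings directly to from_pretrained via the quantization_config argument, rather than attaching them to an AutoConfig object. A minimal sketch of that pattern, assuming bitsandbytes is installed and a CUDA GPU is available (8-bit loading does not run on CPU):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Quantization settings handed straight to from_pretrained;
# the weights are quantized to 8-bit as they are loaded.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",  # place layers on the available GPU(s)
)

One further note on the handler: model.generate applies temperature only when sampling is enabled, so the generate call in the webhook would need do_sample=True for temperature=0.7 to take effect; with the default greedy decoding the argument is ignored (recent transformers versions emit a warning).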
 
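For a quick end-to-end check, the endpoint can be exercised with a Dialogflow-style request body. A sketch under stated assumptions: the route is mounted at POST /webhook on localhost:8000, and the handler derives user_query from queryResult["queryText"] (neither the route path nor the field name is visible in the hunks above).

import requests

# Hypothetical URL and field name; adjust to the actual route and payload shape.
payload = {"queryResult": {"queryText": "Hello"}}
resp = requests.post("http://localhost:8000/webhook", json=payload)
print(resp.json())  # expected shape: {"fulfillmentText": "..."}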