laserbeam2045 committed
Commit 9736832 · Parent(s): a9bf179
Files changed (2):
  1. app.py +27 -21
  2. requirements.txt +0 -2
app.py CHANGED
@@ -9,37 +9,44 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 # Configuration
 # -----------------------------------------------------------------------------
 MODEL_ID = "google/gemma-3-4b-it"
-# If a Hugging Face token is required, set the HUGGINGFACE_TOKEN environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
+DEVICE = torch.device("cpu")  # the free tier is CPU-only

 # -----------------------------------------------------------------------------
-# Device setup (the Spaces free tier is CPU-only)
+# Load the tokenizer
 # -----------------------------------------------------------------------------
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_ID,
+    token=HF_TOKEN,
+    trust_remote_code=True
+)

 # -----------------------------------------------------------------------------
-# Load the tokenizer and model
+# Load the model in low-memory mode
 # -----------------------------------------------------------------------------
-tokenizer = AutoTokenizer.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
     token=HF_TOKEN,
     trust_remote_code=True,
     torch_dtype=torch.float32,
     low_cpu_mem_usage=True
-)
+).to(DEVICE)

-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    token=HF_TOKEN,
-    torch_dtype=torch.float32,  # float32 in a CPU environment
-    device_map="auto" if torch.cuda.is_available() else None
+# -----------------------------------------------------------------------------
+# Apply dynamic quantization
+# -----------------------------------------------------------------------------
+# - quantize the {torch.nn.Linear} modules to INT8
+# - dtype=torch.qint8 quantizes the weights only
+model = torch.quantization.quantize_dynamic(
+    model,
+    {torch.nn.Linear},
+    dtype=torch.qint8
 )
-model.to(device)

 # -----------------------------------------------------------------------------
-# FastAPI definition
+# FastAPI server definition
 # -----------------------------------------------------------------------------
-app = FastAPI(title="Gemma3-4B-IT API")
+app = FastAPI(title="Gemma3-4B-IT with Dynamic Quantization")

 class GenerationRequest(BaseModel):
     prompt: str
@@ -50,16 +57,15 @@ class GenerationRequest(BaseModel):
 @app.post("/generate")
 async def generate(req: GenerationRequest):
     if not req.prompt:
-        raise HTTPException(status_code=400, detail="prompt is required.")
-    # Tokenize
+        raise HTTPException(status_code=400, detail="`prompt` is required.")
+    # Tokenize and run inference
     inputs = tokenizer(
         req.prompt,
         return_tensors="pt",
-        padding=True,
         truncation=True,
-    ).to(device)
-    # Generate
-    generation_output = model.generate(
+        padding=True
+    ).to(DEVICE)
+    output_ids = model.generate(
         **inputs,
         max_new_tokens=req.max_new_tokens,
         temperature=req.temperature,
@@ -67,7 +73,7 @@ async def generate(req: GenerationRequest):
         do_sample=True,
         pad_token_id=tokenizer.eos_token_id
     )
-    text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
+    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
     return {"generated_text": text}

 # -----------------------------------------------------------------------------
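
Note on the change above: the core of this commit is the `torch.quantization.quantize_dynamic` call, which replaces every `torch.nn.Linear` in the loaded model with an INT8-weight equivalent so the 4B model fits more comfortably in CPU memory; dynamic quantization needs no calibration data, which is why it can be applied right after `from_pretrained`. Below is a minimal, self-contained sketch of the same technique on a toy model. It is not the Space's code: `TinyMLP` and its layer sizes are invented purely for illustration.

    import torch
    import torch.nn as nn

    class TinyMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(256, 256)
            self.fc2 = nn.Linear(256, 10)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    model = TinyMLP().eval()

    # Same call as in app.py: swap nn.Linear modules for dynamically quantized
    # versions whose weights are stored as qint8; activations stay float32.
    qmodel = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8,
    )

    print(qmodel.fc1)                          # the Linear is now a dynamically quantized module
    print(qmodel(torch.randn(1, 256)).shape)   # torch.Size([1, 10]); forward still works on CPU
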
requirements.txt CHANGED
@@ -2,5 +2,3 @@ fastapi
 uvicorn[standard]
 transformers>=4.50.0.dev0
 torch
-accelerate>=0.9.0
-safetensors
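
For reference, a hedged client-side sketch of calling the `/generate` endpoint defined in app.py. The base URL is a placeholder (adjust to wherever the app is served, e.g. the Space URL or a local `uvicorn app:app` run), and the `max_new_tokens` / `temperature` fields are inferred from how the handler reads `GenerationRequest`; their defaults are not visible in this diff.

    import requests

    BASE_URL = "http://127.0.0.1:8000"  # placeholder; not part of this commit

    resp = requests.post(
        f"{BASE_URL}/generate",
        json={
            "prompt": "Explain dynamic quantization in one sentence.",
            "max_new_tokens": 64,    # field names inferred from the /generate handler
            "temperature": 0.7,
        },
        timeout=300,  # CPU generation with a 4B model can be slow
    )
    resp.raise_for_status()
    print(resp.json()["generated_text"])
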