laserbeam2045 committed on
Commit 6dd176e · 1 Parent(s): 9736832
Files changed (1) hide show
  1. app.py +34 -44
app.py CHANGED
@@ -1,52 +1,49 @@
1
  # app.py
2
  import os
3
- import torch
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
7
 
8
  # -----------------------------------------------------------------------------
9
- # 設定
10
  # -----------------------------------------------------------------------------
11
- MODEL_ID = "google/gemma-3-4b-it"
12
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
- DEVICE = torch.device("cpu") # 無料枠は CPU のみ
 
 
14
 
15
- # -----------------------------------------------------------------------------
16
- # トークナイザーのロード
17
- # -----------------------------------------------------------------------------
18
- tokenizer = AutoTokenizer.from_pretrained(
19
- MODEL_ID,
20
- token=HF_TOKEN,
21
- trust_remote_code=True
22
- )
23
 
24
  # -----------------------------------------------------------------------------
25
- # モデルのロード+低メモリモード
26
  # -----------------------------------------------------------------------------
27
- model = AutoModelForCausalLM.from_pretrained(
28
- MODEL_ID,
29
- token=HF_TOKEN,
30
- trust_remote_code=True,
31
- torch_dtype=torch.float32,
32
- low_cpu_mem_usage=True
33
- ).to(DEVICE)
 
 
 
34
 
35
  # -----------------------------------------------------------------------------
36
- # 動的量子化の適用
37
  # -----------------------------------------------------------------------------
38
- # - {torch.nn.Linear} を INT8 化
39
- # - dtype=torch.qint8 で重みのみ量子化
40
- model = torch.quantization.quantize_dynamic(
41
- model,
42
- {torch.nn.Linear},
43
- dtype=torch.qint8
44
  )
45
 
46
  # -----------------------------------------------------------------------------
47
- # FastAPI サーバー定義
48
  # -----------------------------------------------------------------------------
49
- app = FastAPI(title="Gemma3-4B-IT with Dynamic Quantization")
50
 
51
  class GenerationRequest(BaseModel):
52
  prompt: str
@@ -58,22 +55,15 @@ class GenerationRequest(BaseModel):
58
  async def generate(req: GenerationRequest):
59
  if not req.prompt:
60
  raise HTTPException(status_code=400, detail="`prompt` は必須です。")
61
- # トークナイズして推論
62
- inputs = tokenizer(
63
  req.prompt,
64
- return_tensors="pt",
65
- truncation=True,
66
- padding=True
67
- ).to(DEVICE)
68
- output_ids = model.generate(
69
- **inputs,
70
- max_new_tokens=req.max_new_tokens,
71
- temperature=req.temperature,
72
  top_p=req.top_p,
73
- do_sample=True,
74
- pad_token_id=tokenizer.eos_token_id
 
 
75
  )
76
- text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
77
  return {"generated_text": text}
78
 
79
  # -----------------------------------------------------------------------------
 
1
  # app.py
2
  import os
 
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
+ from huggingface_hub import hf_hub_download
6
+ from pyllamacpp.model import Model
7
 
8
  # -----------------------------------------------------------------------------
9
+ # Hugging Face Hub の設定
10
  # -----------------------------------------------------------------------------
11
+ HF_TOKEN = os.environ.get("HF_TOKEN") # 必要に応じて Secrets にセット
12
+ REPO_ID = "google/gemma-3-12b-it-qat-q4_0-gguf"
13
+ # 実際にリポジトリに置かれている GGUF ファイル名を確認してください。
14
+ # 例: "gemma-3-12b-it-qat-q4_0-gguf.gguf"
15
+ GGUF_FILENAME = "gemma-3-12b-it-qat-q4_0-gguf.gguf"
16
 
17
+ # キャッシュ先のパス(リポジトリ直下に置く場合)
18
+ MODEL_PATH = os.path.join(os.getcwd(), GGUF_FILENAME)
 
 
 
 
 
 
19
 
20
  # -----------------------------------------------------------------------------
21
+ # 起動時に一度だけダウンロード
22
  # -----------------------------------------------------------------------------
23
+ if not os.path.exists(MODEL_PATH):
24
+ print(f"Downloading {GGUF_FILENAME} from {REPO_ID} …")
25
+ hf_hub_download(
26
+ repo_id=REPO_ID,
27
+ filename=GGUF_FILENAME,
28
+ token=HF_TOKEN,
29
+ repo_type="model", # 明示的にモデルリポジトリを指定
30
+ local_dir=os.getcwd(), # カレントディレクトリに保存
31
+ local_dir_use_symlinks=False
32
+ )
33
 
34
  # -----------------------------------------------------------------------------
35
+ # llama.cpp (pyllamacpp) で 4bit GGUF モデルをロード
36
  # -----------------------------------------------------------------------------
37
+ llm = Model(
38
+ model_path=MODEL_PATH,
39
+ n_ctx=512, # 必要に応じて調整
40
+ n_threads=4, # 実マシンのコア数に合わせて
 
 
41
  )
42
 
43
  # -----------------------------------------------------------------------------
44
+ # FastAPI 定義
45
  # -----------------------------------------------------------------------------
46
+ app = FastAPI(title="Gemma3-12B-IT Q4_0 GGUF API")
47
 
48
  class GenerationRequest(BaseModel):
49
  prompt: str
 
55
  async def generate(req: GenerationRequest):
56
  if not req.prompt:
57
  raise HTTPException(status_code=400, detail="`prompt` は必須です。")
58
+ # llama.cpp の generate を呼び出し
59
+ text = llm.generate(
60
  req.prompt,
 
 
 
 
 
 
 
 
61
  top_p=req.top_p,
62
+ temp=req.temperature,
63
+ n_predict=req.max_new_tokens,
64
+ repeat_last_n=64,
65
+ repeat_penalty=1.1
66
  )
 
67
  return {"generated_text": text}
68
 
69
  # -----------------------------------------------------------------------------
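
For reference, a minimal client sketch against the updated handler. It assumes the handler is mounted at POST /generate on the Space's default port 7860 (the route decorator and the GenerationRequest defaults sit outside this diff); the request fields and the generated_text response key follow the handler shown above.

# client_example.py — minimal sketch; route path and port are assumptions.
import requests

resp = requests.post(
    "http://localhost:7860/generate",    # assumed route and port
    json={
        "prompt": "Explain GGUF in one sentence.",
        "max_new_tokens": 64,            # forwarded as n_predict
        "temperature": 0.7,              # forwarded as temp
        "top_p": 0.9,
    },
    timeout=300,                         # CPU-only generation can be slow
)
resp.raise_for_status()
print(resp.json()["generated_text"])

The generous timeout reflects that a 12B Q4_0 model served on CPU via llama.cpp can take well over a minute per response.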