# app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DEVICE   = torch.device("cpu")  # the free tier is CPU-only

# -----------------------------------------------------------------------------
# Load the tokenizer
# -----------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True
)

# -----------------------------------------------------------------------------
# Load the model in low-memory mode
# -----------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True
).to(DEVICE)
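# Note: low_cpu_mem_usage=True loads the checkpoint incrementally instead of
# materializing the full state dict at once, which keeps peak RAM lower during
# startup. Weights are kept in float32 on CPU; the dynamic quantization step
# below converts the Linear weights to INT8 itself.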

# -----------------------------------------------------------------------------
# Apply dynamic quantization
# -----------------------------------------------------------------------------
# - Convert {torch.nn.Linear} modules to INT8
# - dtype=torch.qint8 quantizes the weights only
model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8
)
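# At inference time the quantized Linear layers keep their weights in INT8 and
# quantize/dequantize activations on the fly, so memory for those layers drops
# roughly 4x versus float32 and CPU matmuls are typically faster. Embeddings
# and non-Linear layers remain in float32.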

# -----------------------------------------------------------------------------
# FastAPI サーバー定義
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT with Dynamic Quantization")

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95

@app.post("/generate")
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="`prompt` is required.")
    # Tokenize the prompt and run inference
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        truncation=True,
        padding=True
    ).to(DEVICE)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": text}

# -----------------------------------------------------------------------------
# For local startup
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
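
# Example request (assuming the server is running locally on the default port):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, who are you?", "max_new_tokens": 64}'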