# app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DEVICE = torch.device("cpu")  # the free tier provides CPU only
# -----------------------------------------------------------------------------
# Load the tokenizer
# -----------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
)
# -----------------------------------------------------------------------------
# Load the model in low-memory mode
# -----------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(DEVICE)
# -----------------------------------------------------------------------------
# Apply dynamic quantization
# -----------------------------------------------------------------------------
# - convert {torch.nn.Linear} modules to INT8
# - dtype=torch.qint8 quantizes the weights only (activations stay float)
model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8,
)
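# Optional sanity check (a quick sketch): printing the model, e.g. `print(model)`,
# should show the former Linear layers as DynamicQuantizedLinear modules if the
# conversion took effect.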
# -----------------------------------------------------------------------------
# FastAPI server definition
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT with Dynamic Quantization")

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95

@app.post("/generate")
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="`prompt` is required.")
    # Tokenize the prompt and run inference
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(DEVICE)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": text}
# -----------------------------------------------------------------------------
# Local entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
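# Example client call once the server is running (a minimal sketch, assuming the
# default port 8000 and that the `requests` package is installed):
#   import requests
#   r = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "Hello, Gemma!", "max_new_tokens": 32},
#   )
#   print(r.json()["generated_text"])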