# app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DEVICE = torch.device("cpu")  # the free tier provides CPU only
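# Note: Gemma models are gated on the Hugging Face Hub, so HF_TOKEN must be set
# in the environment before startup (assumed shell setup):
#   export HF_TOKEN=<your access token>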
# -----------------------------------------------------------------------------
# Load the tokenizer
# -----------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True
)
# -----------------------------------------------------------------------------
# Load the model in low-memory mode
# -----------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True
).to(DEVICE)
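# Rough memory estimate (illustrative, not part of the original code): a
# 4B-parameter model in float32 needs about 4e9 * 4 bytes ≈ 16 GB for the
# weights alone, which is why dynamic INT8 quantization is applied below.
# approx_weight_bytes = sum(p.numel() * p.element_size() for p in model.parameters())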
# -----------------------------------------------------------------------------
# Apply dynamic quantization
# -----------------------------------------------------------------------------
# - Convert {torch.nn.Linear} modules to INT8
# - dtype=torch.qint8 quantizes the weights only
model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8
)
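# Optional sanity check (a minimal sketch, assuming PyTorch's standard
# dynamically-quantized Linear module type; not required for serving):
# quantized = [m for m in model.modules()
#              if isinstance(m, torch.nn.quantized.dynamic.Linear)]
# print(f"{len(quantized)} Linear layers replaced by dynamic INT8 modules")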
# -----------------------------------------------------------------------------
# FastAPI server definition
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT with Dynamic Quantization")
class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95
@app.post("/generate")
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="`prompt` is required.")
    # Tokenize the prompt and run inference
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        truncation=True,
        padding=True
    ).to(DEVICE)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": text}
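# Example request once the server is running (assuming it listens on
# localhost:8000):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, Gemma!", "max_new_tokens": 32}'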
# -----------------------------------------------------------------------------
# For running locally
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
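# Alternatively, launch via the CLI equivalent of the uvicorn.run call above:
#   uvicorn app:app --host 0.0.0.0 --port 8000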